File size: 1,609 Bytes
1a8c724
4d3d295
 
 
9e572dc
4d3d295
9c6a874
 
 
 
 
1a8c724
 
4d3d295
 
 
9820b04
 
4d3d295
 
 
 
 
 
 
 
1a8c724
9c6a874
1a8c724
e9e5a55
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("./checkpoint-15000/")

# 添加示例
examples = [
    ["猪笼草原产于热带和亚热带地区 现主要分布在东南亚一带 中国广东 广西等地有分布 猪笼草喜欢湿润和温暖半阴的生长环境 不耐寒 怕积水 怕强光 怕干燥 喜欢疏松 肥沃和透气的腐叶土和泥炭土 对光照要求较为严格 猪笼草的繁殖方式包括扦插繁殖 压条繁殖和播种繁殖"],

]

def text_processing(text):
    inputs = [text]

    # Tokenize and prepare the inputs for model
    input_ids = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids
    attention_mask = tokenizer(inputs, return_tensors="pt", max_length=512, truncation=True, padding="max_length").attention_mask

    # Generate prediction
    output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=512)

    # Decode the prediction
    decoded_output = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]

    return decoded_output[0]

iface = gr.Interface(fn = text_processing, inputs='text', outputs='text', title='Punctuation Mark Prediction', description='本模型主要用于语言识别模型输出的后处理。\n输入无符号句子,需要打标点处用空格隔开,返回带标点句子。\n仅支持中文,因为训练数据中只有中文。', examples=examples)

iface.launch(inline=False)