yxccai commited on
Commit
1d5b072
·
verified ·
1 Parent(s): c55e334

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -10
app.py CHANGED
@@ -1,10 +1,136 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
-
4
- pipe = pipeline("text2text-generation", model="yxccai/text-style")
5
-
6
- def chat(input_text):
7
- result = pipe(input_text, max_new_tokens=512)[0]["generated_text"]
8
- return result
9
-
10
- gr.Interface(fn=chat, inputs="text", outputs="text").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+ import torch
5
+
6
+ # --- 配置 ---
7
+ # 您上传到Hub的仓库ID (基础模型 + LoRA适配器)
8
+ hub_repo_id = "yxccai/text-style"
9
+ # Qwen模型的基础模型名称 (与您微调时使用的基础模型一致)
10
+ # 例如: "Qwen/Qwen1.5-1.8B-Chat" "Qwen/Qwen1.5-0.5B-Chat"
11
+ # 这个信息通常在您的LoRA适配器配置文件 (adapter_config.json) 中的 base_model_name_or_path 字段
12
+ # 您需要在这里明确指定它,因为我们要先加载基础模型
13
+ base_model_name = "Qwen/Qwen1.5-1.8B-Chat" # 假设您微调的是1.8B版本,请根据实际情况修改
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ print(f"Gradio App: Using device: {device}")
17
+
18
+ # --- 加载模型和Tokenizer ---
19
+ print(f"Gradio App: Loading base model: {base_model_name}")
20
+ # 1. 加载基础模型
21
+ base_model = AutoModelForCausalLM.from_pretrained(
22
+ base_model_name,
23
+ torch_dtype="auto", # 或者 torch.float16, torch.bfloat16
24
+ # device_map="auto", # 在Spaces中,直接 .to(device) 可能更稳定
25
+ trust_remote_code=True
26
+ # quantization_config=... # 如果基础模型加载时需要量化,这里也要配置
27
+ )
28
+ base_model.to(device)
29
+
30
+ print(f"Gradio App: Loading tokenizer from: {hub_repo_id}")
31
+ # 2. 加载Tokenizer (从您上传的仓库,它应该包含了基础模型的tokenizer配置)
32
+ tokenizer = AutoTokenizer.from_pretrained(hub_repo_id, trust_remote_code=True)
33
+ if tokenizer.pad_token is None:
34
+ tokenizer.pad_token = tokenizer.eos_token
35
+ base_model.config.pad_token_id = tokenizer.eos_token_id
36
+
37
+
38
+ print(f"Gradio App: Loading LoRA adapter from: {hub_repo_id}")
39
+ # 3. 加载并应用LoRA适配器
40
+ # hub_repo_id 指向的是包含LoRA适配器权重 (adapter_model.bin) 和配置 (adapter_config.json) 的仓库
41
+ model = PeftModel.from_pretrained(base_model, hub_repo_id)
42
+ # model.to(device) # base_model 已经 to(device) 了,PeftModel会继承
43
+
44
+ # (可选) 如果希望合并权重以简化,但会占用更多内存/磁盘
45
+ # print("Gradio App: Merging LoRA adapter...")
46
+ # model = model.merge_and_unload()
47
+ # print("Gradio App: LoRA adapter merged.")
48
+
49
+ model.eval() # 设置为评估模式
50
+ print("Gradio App: Model and tokenizer loaded successfully.")
51
+
52
+ # --- 推理函数 ---
53
+ def chat(input_text):
54
+ print(f"Gradio App: Received input: {input_text}")
55
+ # 构建符合Qwen Chat模板的输入
56
+ messages = [
57
+ {"role": "system", "content": "你是一个文本风格转换助手。请严格按照要求,仅将以下书面文本转换为自然、口语化的简洁表达方式,不要添加任何额外的解释、扩展信息或重复原文。"},
58
+ {"role": "user", "content": input_text}
59
+ ]
60
+
61
+ # 使用 apply_chat_template
62
+ # 注意:Hugging Face Spaces环境中的transformers版本可能与Colab不同
63
+ # 确保 apply_chat_template 的用法与您测试时一致
64
+ try:
65
+ prompt = tokenizer.apply_chat_template(
66
+ messages,
67
+ tokenize=False,
68
+ add_generation_prompt=True # 推理时需要模型知道何时开始生成
69
+ )
70
+ except Exception as e:
71
+ print(f"Error applying chat template: {e}")
72
+ # 回退到一个简单的拼接方式,但这可能不是最优的
73
+ prompt = messages[0]["content"] + "\n" + messages[1]["content"] + "\n" + tokenizer.eos_token # 或者其他适合的格式
74
+
75
+
76
+ print(f"Gradio App: Formatted prompt for model:\n{prompt}")
77
+
78
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device) # 调整max_length
79
+
80
+ generated_ids = model.generate(
81
+ **inputs,
82
+ max_new_tokens=2048, # 控制输出长度
83
+ num_beams=1, # 可以尝试增加
84
+ do_sample=True,
85
+ temperature=0.7,
86
+ top_k=50,
87
+ top_p=0.95,
88
+ pad_token_id=tokenizer.eos_token_id
89
+ )
90
+
91
+ # 解码生成的token IDs
92
+ # generated_ids[0] 包含了输入提示和模型生成的部分
93
+ full_generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=False) # 保留特殊token以帮助分割
94
+ print(f"Gradio App: Full generated sequence:\n{full_generated_text}")
95
+
96
+ # 从完整序列中提取assistant的回复
97
+ assistant_marker_start = "<|im_start|>assistant" # Qwen的标记
98
+
99
+ if assistant_marker_start in full_generated_text:
100
+ parts = full_generated_text.split(assistant_marker_start)
101
+ if len(parts) > 1:
102
+ assistant_reply = parts[-1].strip()
103
+ # 移除可能的结束标记,如 <|im_end|> 或 eos_token
104
+ if assistant_reply.endswith(tokenizer.eos_token):
105
+ assistant_reply = assistant_reply[:-len(tokenizer.eos_token)].strip()
106
+ elif "<|im_end|>" in assistant_reply: # Qwen的聊天模板使用 <|im_end|>
107
+ assistant_reply = assistant_reply.split("<|im_end|>")[0].strip()
108
+ result = assistant_reply
109
+ else:
110
+ result = "模型未能生成assistant标记后的回复。"
111
+ else:
112
+ # 如果找不到 assistant 标记,尝试从原始prompt之后提取
113
+ # 这需要原始prompt的token数量
114
+ # 另一种简单方式是直接解码去除特殊token的生成部分,但这可能包含一些模板残留
115
+ result = tokenizer.decode(generated_ids[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True).strip()
116
+ if not result: # 如果这种方式结果为空,可能解码时skip_special_tokens去除了所有
117
+ result = "模型输出格式不符合预期,未能提取有效回复。"
118
+
119
+ print(f"Gradio App: Extracted result: {result}")
120
+ return result
121
+
122
+ # --- 创建Gradio界面 ---
123
+ iface = gr.Interface(
124
+ fn=chat,
125
+ inputs=gr.Textbox(lines=5, label="输入书面文本 (Input Formal Text)"),
126
+ outputs=gr.Textbox(lines=5, label="输出口语化文本 (Output Casual Text)"),
127
+ title="文本风格转换器 (Text Style Converter)",
128
+ description="输入一段书面化的中文文本,模型会尝试将其转换为更自然、口语化的表达方式。由Qwen模型微调。",
129
+ examples=[
130
+ ["乙醇的检测方法包括以下几项: 1. 酸碱度检查:取20ml乙醇加20ml水,加2滴酚酞指示剂应无色,再加1ml 0.01mol/L氢氧化钠应显粉红色."],
131
+ ["本公司今日发布了最新的财务业绩报告,数据显示本季度利润实现了显著增长。"]
132
+ ]
133
+ )
134
+
135
+ if __name__ == "__main__":
136
+ iface.launch()