ykallan committed on
Commit
8a6a16d
·
verified ·
1 Parent(s): 6459c4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -41
app.py CHANGED
@@ -1,66 +1,60 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
 
 
4
  """
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
6
  """
7
 
# Hub repository id of the model that the chat handler queries.
pretrained_model = "ykallan/SkuInfo-Qwen2.5-3B-Instruct"

# Inference API client bound to that repository; used for chat completions.
client = InferenceClient(model=pretrained_model)
 
 
 
11
 
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for ``message`` from the Inference API client.

    The request is built from the system prompt, the previous turns in
    ``history`` and the new user ``message``, and sent through the
    module-level ``client`` with streaming enabled.

    Args:
        message: The new user message.
        history: Previous turns as ``(user, assistant)`` pairs; empty
            entries of a pair are skipped.
        system_message: Content of the leading ``system`` message.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        top_p: Forwarded as ``top_p``.

    Yields:
        The response text accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Use a dedicated loop name so the `message` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # A streamed chunk may carry no text (content is None); appending it
        # blindly would raise TypeError, so only concatenate real text.
        if token:
            response += token
        yield response
 
43
 
 
44
 
# For information on how to customize the ChatInterface, peruse the gradio
# docs: https://www.gradio.app/docs/chatinterface

# Extra controls shown next to the chat box; their values are passed to
# `respond` after (message, history), in this order.
_system_message_box = gr.Textbox(
    value="在以下商品名称中抽取出品牌、型号、主商品,并以JSON格式返回。",
    label="System message",
)
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"
)
_temperature_slider = gr.Slider(
    minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(
    minimum=0.1,
    maximum=1.0,
    value=0.95,
    step=0.05,
    label="Top-p (nucleus sampling)",
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        _system_message_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
    ],
)


if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
 
2
 
3
+ from transformers import AutoModel, AutoTokenizer
4
  """
5
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
+ https://huggingface.co/spaces/jclian91/Chinese_Late_Chunking/blob/main/app.py
7
+
8
+
9
  """
10
 
# `respond` calls `model.generate(...)`, which needs a language-modeling head.
# `AutoModel` loads only the bare backbone, so the causal-LM auto class is
# imported here and used for loading instead.
from transformers import AutoModelForCausalLM

# Hub repository id of the checkpoint loaded below.
pretrained_model = "ykallan/SkuInfo-Qwen2.5-3B-Instruct"

# load model and tokenizer
# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repository to run locally; keep it only if the checkpoint requires it.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(pretrained_model, trust_remote_code=True)
16
+
17
 
def respond(
    sku_name,
    history: list[tuple[str, str]] = None,
    system_message=None,
    max_tokens=None,
    temperature=None,
    top_p=None,
):
    """Generate the model's answer for one SKU name using the local model.

    A two-message chat (system prompt + ``sku_name``) is rendered with the
    tokenizer's chat template, run through ``model.generate`` and only the
    newly generated tokens are decoded.

    Args:
        sku_name: Product name to send as the user message.
        history: Accepted for signature compatibility; not used.
        system_message: System prompt. When ``None`` the built-in extraction
            prompt is used.
        max_tokens: Upper bound for newly generated tokens. When ``None``
            the original limit of 128 is used.
        temperature: Accepted for signature compatibility; not forwarded.
        top_p: Accepted for signature compatibility; not forwarded.

    Returns:
        The decoded completion text (special tokens stripped).
    """
    if system_message is None:
        system_message = "在以下商品名称中抽取出品牌、型号、主商品,并以JSON格式返回。"
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": sku_name},
    ]

    # tokenize=False returns the rendered prompt as text, not token ids.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Place the inputs on the model's own device; the previous code referenced
    # an undefined `device` name and failed with NameError.
    model_inputs = tokenizer([prompt], return_tensors="pt", padding=True).to(model.device)

    max_new_tokens = 128 if max_tokens is None else int(max_tokens)
    # Unpack the whole encoding so the attention mask reaches generate() too.
    output_batch = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + completion; drop the prompt tokens per row.
    completions = [
        output_row[len(prompt_row):]
        for prompt_row, output_row in zip(model_inputs.input_ids, output_batch)
    ]
    return tokenizer.batch_decode(completions, skip_special_tokens=True)[0]
42
 
 
43
 
 
 
 
 
 
 
 
 
44
 
if __name__ == "__main__":
    # Minimal Blocks UI: one text input, a submit button and a result box.
    with gr.Blocks() as demo:
        sku_name = gr.TextArea(lines=1, placeholder="your query", label="skuName")

        # Without an output component the handler's return value is discarded.
        result = gr.Textbox(label="result")

        submit = gr.Button("Submit")

        # gr.Examples requires an `examples` argument; the call previously
        # omitted it. NOTE(review): the sample row below is an illustrative
        # placeholder - replace it with real SKU names.
        examples = gr.Examples(
            examples=[["华为 Mate 60 Pro 12GB+512GB 智能手机"]],
            inputs=[sku_name],
        )

        # The handler is `respond` (the old `fn=response` was an undefined
        # name). Only the SKU name comes from the UI, so the remaining
        # positional arguments of `respond` are filled in explicitly.
        submit.click(
            fn=lambda name: respond(name, [], None, None, None, None),
            inputs=[sku_name],
            outputs=[result],
        )
    demo.launch()