Walmart-the-bag committed on
Commit
c04565c
·
verified ·
1 Parent(s): c18201d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import TextIteratorStreamer
3
+ from threading import Thread
4
+ from transformers import StoppingCriteria, StoppingCriteriaList
5
+ import torch
6
+ import os
7
+ from unsloth import FastLanguageModel
8
# Hugging Face model id for the checkpoint served by this demo.
model_name = "microsoft/Phi-3-medium-128k-instruct"
# NOTE(review): AutoModelForCausalLM / AutoTokenizer are imported but unused —
# loading goes through unsloth's FastLanguageModel below; confirm before removing.
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load model and tokenizer in one call via unsloth's optimized loader,
# placing the weights on the GPU (device_map='cuda').
model, tokenizer = FastLanguageModel.from_pretrained(model_name, device_map='cuda')
11
+
12
+
13
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation once the most recently
    generated token id is one of the configured stop ids.

    The stop ids were previously hard-coded inside ``__call__``; they are now
    a constructor parameter with the same values as the default, so existing
    ``StopOnTokens()`` call sites behave identically.
    """

    def __init__(self, stop_ids=(29, 0)):
        # 29 and 0 are presumably the chat end-of-turn / eos token ids for
        # this tokenizer — TODO confirm against the Phi-3 tokenizer config.
        super().__init__()
        self.stop_ids = tuple(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        return int(input_ids[0][-1]) in self.stop_ids
20
def predict(message, history):
    """Stream a chat completion for *message*, given the prior *history*.

    *history* is a list of [user, assistant] pairs (gradio ChatInterface
    format). Yields the progressively growing response string so the UI can
    render tokens as they arrive.
    """
    turns = history + [[message, ""]]

    # Serialize every turn (including the new, empty assistant slot) into the
    # chat template used for this model.
    prompt_parts = []
    for user_text, assistant_text in turns:
        prompt_parts.append("\n<|end|>\n<|user|>\n" + user_text)
        prompt_parts.append("\n<|end|>\n<|assistant|>\n" + assistant_text)
    prompt = "".join(prompt_parts)

    encoded = tokenizer([prompt], return_tensors="pt").to("cuda")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True
    )
    generation_args = dict(
        encoded,
        streamer=streamer,
        max_new_tokens=4096,
        do_sample=True,
        top_p=0.8,
        top_k=40,
        temperature=0.9,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # Run generation on a worker thread; the streamer feeds decoded text
    # fragments back to this generator as they are produced.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    response = ""
    for token_text in streamer:
        # Skip bare '<' fragments (partial special-token markup).
        if token_text != '<':
            response += token_text
            yield response
44
+
45
# Wire the streaming predict() generator into a chat UI and launch it with a
# public share link (share=True creates a temporary gradio.live URL).
demo = gr.ChatInterface(fn=predict, examples=["Write me a python snake game code", "Write me a ping pong game code"], title="Phi-3-medium-128k-instruct")
demo.launch(share=True)