rombodawg committed
Commit 6df35ce (verified)
Parent: 1aedf4d

Update app.py

Files changed (1): app.py +20 -10
app.py CHANGED
@@ -4,7 +4,8 @@ from threading import Thread
 from transformers import StoppingCriteria, StoppingCriteriaList
 import torch
 import spaces
-import os
+import os
+
 model_name = "microsoft/Phi-3-medium-128k-instruct"
 from transformers import AutoModelForCausalLM, AutoTokenizer
 model = AutoModelForCausalLM.from_pretrained(model_name, device_map='cuda', torch_dtype=torch.float16, trust_remote_code=True)
@@ -17,22 +18,22 @@ class StopOnTokens(StoppingCriteria):
             if input_ids[0][-1] == stop_id:
                 return True
         return False
-@spaces.GPU()
-def predict(message, history):
+
+@spaces.GPU(duration=120)
+def predict(message, history, temperature, max_tokens, top_p, top_k):
     history_transformer_format = history + [[message, ""]]
     stop = StopOnTokens()
     messages = "".join(["".join(["\n<|end|>\n<|user|>\n"+item[0], "\n<|end|>\n<|assistant|>\n"+item[1]]) for item in history_transformer_format])
-    #messages = "".join(["".join(["<user>"+item[0], "<output>"+item[1]]) for item in history_transformer_format])
     model_inputs = tokenizer([messages], return_tensors="pt").to("cuda")
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
         streamer=streamer,
-        max_new_tokens=8192,
+        max_new_tokens=max_tokens,
         do_sample=True,
-        top_p=0.8,
-        top_k=40,
-        temperature=0.9,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
         stopping_criteria=StoppingCriteriaList([stop])
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -43,5 +44,14 @@ def predict(message, history):
         partial_message += new_token
         yield partial_message
 
-demo = gr.ChatInterface(fn=predict, examples=["Write me a python snake game code", "Write me a ping pong game code"], title="Phi-3-medium-128k-instruct")
-demo.launch(share=True)
+demo = gr.ChatInterface(
+    fn=predict,
+    title="Phi-3-medium-128k-instruct",
+    additional_inputs=[
+        gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
+        gr.Slider(512, 8192, value=4096, label="Max Tokens"),
+        gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
+        gr.Slider(10, 90, value=40, label="top_k"),
+    ]
+)
+demo.launch(share=True)
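
The first substantive change is the decorator: `@spaces.GPU()` becomes `@spaces.GPU(duration=120)`. On ZeroGPU Spaces, `duration` is the number of seconds a single call to the decorated function may hold the GPU, so raising it gives long generations (up to the new 8192-token slider maximum) more headroom. A minimal sketch of the decorator in isolation, assuming it runs inside a ZeroGPU Space (the doubling function is hypothetical, not part of this Space):

import spaces
import torch

@spaces.GPU(duration=120)  # allow this call to hold the GPU for up to 120 s
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # On ZeroGPU, CUDA is attached only while a @spaces.GPU function runs,
    # so tensors are moved on and off the device inside the decorated scope.
    return (x.to("cuda") * 2).cpu()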
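
The second change threads the new sampling controls through to `model.generate`. `gr.ChatInterface` passes the value of each component in `additional_inputs` to `fn` as an extra positional argument after `(message, history)`, in list order, which is why the slider order (Temperature, Max Tokens, top_p, top_k) mirrors the new `predict` signature. A self-contained sketch of that wiring, with a hypothetical echo function standing in for the real model call:

import gradio as gr

# Each slider value arrives as a positional argument after (message, history),
# in the order the components appear in additional_inputs.
def predict(message, history, temperature, max_tokens, top_p, top_k):
    # Hypothetical echo body; the real Space streams tokens from Phi-3 here.
    yield (f"{message!r} with temperature={temperature}, max_tokens={max_tokens}, "
           f"top_p={top_p}, top_k={top_k}")

demo = gr.ChatInterface(
    fn=predict,
    additional_inputs=[
        gr.Slider(0.1, 0.9, value=0.7, label="Temperature"),
        gr.Slider(512, 8192, value=4096, label="Max Tokens"),
        gr.Slider(0.1, 0.9, value=0.7, label="top_p"),
        gr.Slider(10, 90, value=40, label="top_k"),
    ],
)

if __name__ == "__main__":
    demo.launch()

Note that the rewrite also drops the old `examples=` list from the `gr.ChatInterface` call; everything else in the launch block is unchanged.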