hannayukhymenko (HF Staff) committed
Commit a683b22 · verified · 1 Parent(s): 9939585

Update app.py

Files changed (1):
  app.py +10 -15
app.py CHANGED
@@ -1,23 +1,23 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import torch
  import spaces
- import threading

- model_name = "INSAIT-Institute/MamayLM-Gemma-2-9B-IT-v0.1"
+
+ model_name = "INSAIT-Institute/MamayLM-Gemma-2-9B-IT-v0.1"
  tokenizer = AutoTokenizer.from_pretrained(model_name)
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

  @spaces.GPU
  def respond(message, chat_history, system_message, max_new_tokens, temperature, top_p):
+
      prompt = f"{system_message.strip()}\n"
      for user, bot in chat_history:
          prompt += f"User: {user}\nAssistant: {bot}\n"
      prompt += f"User: {message}\nAssistant:"

      inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     generation_kwargs = dict(
+     output = model.generate(
          **inputs,
          max_new_tokens=int(max_new_tokens),
          pad_token_id=tokenizer.eos_token_id,
@@ -25,17 +25,11 @@ def respond(message, chat_history, system_message, max_new_tokens, temperature,
          temperature=float(temperature),
          top_p=float(top_p),
          eos_token_id=tokenizer.eos_token_id,
-         streamer=streamer,
      )
+     decoded = tokenizer.decode(output[0], skip_special_tokens=True)

-     thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-
-     partial_text = ""
-     for new_text in streamer:
-         partial_text += new_text
-         response = partial_text.split("User:")[0].strip()
-         yield response
+     response = decoded.split("Assistant:")[-1].strip().split("User:")[0].strip()
+     return response

  def clear_fn():
      return None
@@ -60,6 +54,7 @@ chat = gr.ChatInterface(
      ],
      title="💬 Chat with MamayLM",
      description="A multi-turn chat interface for MamayLM-v0.1-9B with configurable parameters.",
-     theme="soft",
+     theme="soft"
  )
+
  chat.launch()
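For reference, the new non-streaming path decodes the whole generation and then slices out the final assistant turn. A minimal sketch of just that parsing step, run on a made-up decoded string (no model required; the real string comes from tokenizer.decode):

    # Hypothetical decoded output for a two-turn chat (illustration only).
    decoded = (
        "You are a helpful assistant.\n"
        "User: Hi\n"
        "Assistant: Hello! How can I help?\n"
        "User: What is 2 + 2?\n"
        "Assistant: 2 + 2 = 4.\nUser:"
    )

    # Keep everything after the last "Assistant:" marker, then cut at any
    # "User:" the model may have started generating on its own.
    response = decoded.split("Assistant:")[-1].strip().split("User:")[0].strip()
    print(response)  # -> 2 + 2 = 4.

The trailing split on "User:" guards against the model continuing the dialogue past its own turn, which plain-text prompts like this one commonly provoke.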
 
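Because respond now returns a single string instead of yielding partial text, it can be smoke-tested directly, outside Gradio. A hypothetical call, assuming the model and tokenizer from app.py have loaded (the system message and sampling values below are made up for illustration):

    # Hypothetical direct invocation of respond() for a quick sanity check.
    reply = respond(
        message="Hello!",
        chat_history=[],                                    # no prior turns
        system_message="You are MamayLM, a helpful assistant.",  # hypothetical
        max_new_tokens=128,
        temperature=0.7,
        top_p=0.9,
    )
    print(reply)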