jonathantiedchen committed on
Commit 394a0e3 · verified · 1 Parent(s): e950877

Update app.py

Files changed (1)
  1. app.py +6 -3
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteriaList
 import importlib
 import random
 from datasets import load_dataset
+from utils import SpecificStringStoppingCriteria
 
 # Streamlit UI
 st.title("🧠 Math LLM Demo")
@@ -55,12 +56,14 @@ if st.button("Generate Response", key="manual"):
 
     #MISTRAL PROMPTING
     inputs = mistral_tokenizer(prompt, return_tensors="pt").to(mistral.device)
+    stop_criteria = SpecificStringStoppingCriteria(mistral_tokenizer, generation_util, len(prompt))
+    stopping_criteria_list = StoppingCriteriaList([stop_criteria])
     with torch.no_grad():
         outputs = mistral.generate(
             **inputs,
             max_new_tokens=512,
             pad_token_id=mistral_tokenizer.eos_token_id,
-            eos_token_id=mistral_tokenizer.eos_token_id
+            stopping_criteria=stopping_criteria_list
         )
     generated_text = mistral_tokenizer.decode(outputs[0], skip_special_tokens=True)
     if generated_text.startswith(prompt):
@@ -69,8 +72,8 @@ if st.button("Generate Response", key="manual"):
         response_only = generated_text.strip()
 
     st.subheader("🔎 Prompt")
-    st.code(prompt)
+    st.write(prompt)
     st.subheader("🧠 Model Output")
-    st.code(generated_text)
+    st.write(generated_text)
     st.subheader("✂️ Response Only")
     st.success(response_only)
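
The utils.SpecificStringStoppingCriteria class pulled in by the new import is not shown in this commit. Below is a minimal sketch of how such a criterion is typically written against the transformers StoppingCriteria interface, assuming (from the call site) that generation_util is an iterable of stop strings and that the third constructor argument is the prompt length in characters:

# Sketch only -- not the actual utils.py; the signature is inferred from the call
# SpecificStringStoppingCriteria(mistral_tokenizer, generation_util, len(prompt)).
from transformers import StoppingCriteria

class SpecificStringStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings, prompt_length):
        self.tokenizer = tokenizer
        self.stop_strings = stop_strings    # assumed: strings that mark where generation should end
        self.prompt_length = prompt_length  # characters of the prompt to skip when checking

    def __call__(self, input_ids, scores, **kwargs):
        # Decode everything generated so far and inspect only the part after the prompt.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        generated = text[self.prompt_length:]
        # Returning True tells generate() to stop at the current step.
        return any(stop in generated for stop in self.stop_strings)

Wired into mistral.generate() through StoppingCriteriaList, a criterion like this halts generation as soon as one of the stop strings appears in the decoded continuation, rather than relying on eos_token_id alone as the previous version did.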