jonathantiedchen committed on
Commit
33b4451
·
verified ·
1 Parent(s): 8dccc28

Update app.py

Files changed (1): app.py (+22, -6)
app.py CHANGED
@@ -8,11 +8,14 @@ import importlib
 import random
 from datasets import load_dataset
 from utils import SpecificStringStoppingCriteria
+from cot import EIGHT_SHOT_PROMPT, FOUR_SHOT_PROMPT
 
 # Streamlit UI
 st.title("🧠 Math LLM Demo")
 st.write("💬 Please prompt something!")
 
+use_cot = st.toggle("Activate feature")
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Some Specifications
@@ -104,13 +107,26 @@ model = models[model_path]["model"]
 # - add a history of the prompts similar to chat format
 prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
 
+if use_cot:
+    if 'mistral' in model_choice.lower():
+        # use 8-shot prompt
+        prompt_template = EIGHT_SHOT_PROMPT
+        input_text = prompt_template.format(question=prompt)
+
+    elif 'small' in model_choice.lower() or 'gpt' in model_choice.lower():
+        # use 4-shot prompt
+        prompt_template = FOUR_SHOT_PROMPT
+        input_text = prompt_template.format(question=prompt)
+    else:
+        input_text = prompt
+
 if st.button("Generate Response", key="manual"):
     with st.sidebar:
         with st.spinner("🔄 Generating..."):
 
             # Configuration needed for all models
-            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-            stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(prompt))
+            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+            stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(input_text))
             stopping_criteria_list = StoppingCriteriaList([stop_criteria])
 
             # Statement to check model version; different models need different prompting strategies
@@ -124,8 +140,8 @@ if st.button("Generate Response", key="manual"):
                     stopping_criteria=stopping_criteria_list
                 )
                 generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                if generated_text.startswith(prompt):
-                    response_only = generated_text[len(prompt):].strip()
+                if generated_text.startswith(input_text):
+                    response_only = generated_text[len(input_text):].strip()
                 else:
                     response_only = generated_text.strip()
 
@@ -139,7 +155,7 @@ if st.button("Generate Response", key="manual"):
                     stopping_criteria=stopping_criteria_list
                 )
                 generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-                response_only = generated_text[len(prompt):].strip()
+                response_only = generated_text[len(input_text):].strip()
 
             else:
                 st.error("⚠️ Problems in identifying the model.")
@@ -147,7 +163,7 @@ if st.button("Generate Response", key="manual"):
                 response_only = "Error: Model not recognized"
 
             st.subheader("🔎 Prompt")
-            st.write(prompt)
+            st.write(input_text)
             #st.subheader("🧠 Model Output")
             #st.write(generated_text)
             st.subheader("🧠 Model Output")
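The cot module that this commit starts importing is not included in the diff. As a rough sketch only, assuming GSM8K-style chain-of-thought exemplars, its constants could look like the following; the exemplar text and module contents are assumptions, and only the names EIGHT_SHOT_PROMPT and FOUR_SHOT_PROMPT plus the {question} placeholder (implied by prompt_template.format(question=prompt)) come from the diff:

# cot.py (hypothetical sketch; the real constants are not shown in this commit)
# Each constant is a chain-of-thought few-shot prefix that app.py fills in via
# prompt_template.format(question=prompt), so it must contain a {question} slot.

FOUR_SHOT_PROMPT = """\
Question: Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Answer: In April she sold 48 clips. In May she sold 48 / 2 = 24 clips. Altogether she sold 48 + 24 = 72 clips. The answer is 72.

Question: Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?
Answer: 50 minutes is 50 / 60 of an hour, so she earned 12 * 50 / 60 = 10 dollars. The answer is 10.

Question: {question}
Answer:"""

# EIGHT_SHOT_PROMPT would follow the same pattern with eight worked examples;
# only two are shown above to keep the sketch short.
EIGHT_SHOT_PROMPT = FOUR_SHOT_PROMPT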
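In the new block, only branches under if use_cot: assign input_text, so depending on how the final else is indented in the actual file, either turning the toggle off or choosing a model that matches none of the name checks can leave input_text unassigned and make the later tokenizer(input_text, ...) call raise a NameError. A defensive variant (hypothetical, not part of the commit) that always defines input_text could look like this:

# Hypothetical fallback, not in the commit: make sure input_text always exists.
if use_cot:
    if 'mistral' in model_choice.lower():
        input_text = EIGHT_SHOT_PROMPT.format(question=prompt)
    elif 'small' in model_choice.lower() or 'gpt' in model_choice.lower():
        input_text = FOUR_SHOT_PROMPT.format(question=prompt)
    else:
        input_text = prompt  # unknown model: skip the few-shot prefix
else:
    input_text = prompt      # CoT toggle off: send the raw prompt unchanged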
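The commit also switches the offset passed to SpecificStringStoppingCriteria from len(prompt) to len(input_text), presumably so that stop strings appearing inside the few-shot exemplars do not end generation immediately. The class itself lives in utils.py and is not shown here; below is a minimal sketch of how such a criterion is commonly written against the transformers StoppingCriteria interface, with the constructor signature inferred from the call site and everything else an assumption:

# Hypothetical sketch of SpecificStringStoppingCriteria; the real implementation
# lives in utils.py and is not part of this commit.
from transformers import StoppingCriteria


class SpecificStringStoppingCriteria(StoppingCriteria):
    """Stop generation once any stop string appears after the prompt text."""

    def __init__(self, tokenizer, stop_strings, prompt_length):
        self.tokenizer = tokenizer
        self.stop_strings = stop_strings    # e.g. generation_util in app.py
        self.prompt_length = prompt_length  # len(input_text): characters to skip

    def __call__(self, input_ids, scores, **kwargs):
        # Decode the full sequence, drop the prompt prefix, and check whether
        # any stop string has been generated yet.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        generated_part = text[self.prompt_length:]
        return any(s in generated_part for s in self.stop_strings)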