Update app.py

app.py CHANGED
@@ -8,11 +8,14 @@ import importlib
 import random
 from datasets import load_dataset
 from utils import SpecificStringStoppingCriteria
+from cot import EIGHT_SHOT_PROMPT, FOUR_SHOT_PROMPT
 
 # Streamlit UI
 st.title("🧠 Math LLM Demo")
 st.write("💬 Please prompt something!")
 
+use_cot = st.toggle("Activate feature")
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 # Some Specifications
@@ -104,13 +107,26 @@ model = models[model_path]["model"]
 # - add a history of the prompts similar to chat format
 prompt = st.text_area("Enter your math prompt:", "Jasper has 5 apples and eats 2 of them. How many apples does he have left?")
 
+if use_cot:
+    if 'mistral' in model_choice.lower():
+        #use 8 shot prompt
+        prompt_template = EIGHT_SHOT_PROMPT
+        input_text = prompt_template.format(question=prompt)
+
+    elif 'small' in model_choice.lower() or 'gpt' in model_choice.lower():
+        #use 4 shot prompt
+        prompt_template = FOUR_SHOT_PROMPT
+        input_text = prompt_template.format(question=prompt)
+else:
+    input_text = prompt
+
 if st.button("Generate Response", key="manual"):
     with st.sidebar:
         with st.spinner("🚀 Generating..."):
 
             # Configuration needed for all models
-            inputs = tokenizer(
-            stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(
+            inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
+            stop_criteria = SpecificStringStoppingCriteria(tokenizer, generation_util, len(input_text))
             stopping_criteria_list = StoppingCriteriaList([stop_criteria])
 
             # Statement to check model version, different model need different prompting strategy
@@ -124,8 +140,8 @@ if st.button("Generate Response", key="manual"):
                     stopping_criteria=stopping_criteria_list
                 )
                 generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-                if generated_text.startswith(
-                response_only = generated_text[len(
+                if generated_text.startswith(input_text):
+                    response_only = generated_text[len(input_text):].strip()
                 else:
                     response_only = generated_text.strip()
 
@@ -139,7 +155,7 @@ if st.button("Generate Response", key="manual"):
                     stopping_criteria=stopping_criteria_list
                 )
                 generated_text = tokenizer.decode(output[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
-                response_only = generated_text[len(
+                response_only = generated_text[len(input_text):].strip()
 
             else:
                 st.error("⚠️ Problems in identifying the model.")
@@ -147,7 +163,7 @@ if st.button("Generate Response", key="manual"):
                 response_only = "Error: Model not recognized"
 
     st.subheader("📝 Prompt")
-    st.write(
+    st.write(input_text)
     #st.subheader("🧠 Model Output")
     #st.write(generated_text)
     st.subheader("🧠 Model Output")
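The new prompting path imports EIGHT_SHOT_PROMPT and FOUR_SHOT_PROMPT from cot, but cot.py itself is not touched by this commit. A minimal sketch of what those templates could look like, assuming GSM8K-style worked examples and a {question} placeholder, which is the only interface visible from prompt_template.format(question=prompt) above; the exemplar text and helper names are illustrative, not the actual file contents.

# cot.py -- hypothetical sketch; the real module is not shown in this diff.
# Each template ends with a {question} slot that app.py fills via
# prompt_template.format(question=prompt).

_SHOT = (
    "Question: Natalia sold clips to 48 of her friends in April, and then "
    "she sold half as many clips in May. How many clips did Natalia sell "
    "altogether?\n"
    "Answer: In April she sold 48 clips. In May she sold 48 / 2 = 24 clips. "
    "In total she sold 48 + 24 = 72 clips. The answer is 72.\n\n"
)

# The real templates would hold four and eight distinct worked examples;
# one exemplar is repeated here only to keep the sketch short.
FOUR_SHOT_PROMPT = _SHOT * 4 + "Question: {question}\nAnswer:"
EIGHT_SHOT_PROMPT = _SHOT * 8 + "Question: {question}\nAnswer:"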
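SpecificStringStoppingCriteria comes from utils, which is also outside this diff; the call site only shows that it receives the tokenizer, a generation_util object, and the prompt length, and that it is wrapped in a transformers StoppingCriteriaList. A plausible sketch, assuming generation_util is a list of stop strings and that generation should halt once any of them appears in the newly generated text:

# utils.py -- hypothetical sketch; the real implementation is not shown here.
import torch
from transformers import StoppingCriteria

class SpecificStringStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_strings, prompt_length):
        self.tokenizer = tokenizer
        self.stop_strings = stop_strings    # assumed: e.g. ["Question:", "\n\n\n"]
        self.prompt_length = prompt_length  # character length of the prompt text

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode the running sequence, drop the prompt prefix, and stop as soon
        # as any stop string shows up in the newly generated portion.
        text = self.tokenizer.decode(input_ids[0], skip_special_tokens=True)
        generated = text[self.prompt_length:]
        return any(s in generated for s in self.stop_strings)

app.py then passes StoppingCriteriaList([stop_criteria]) to model.generate(..., stopping_criteria=stopping_criteria_list), which is the standard transformers hook for custom stop conditions.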