Spaces:
Runtime error
Runtime error
Ubuntu
commited on
Commit
·
c5556d8
1
Parent(s):
fa25719
update
Browse files
app.py
CHANGED
@@ -17,6 +17,7 @@ def infer(
|
|
17 |
top_p=1.0,
|
18 |
top_k=40,
|
19 |
num_completions=1,
|
|
|
20 |
seed=42,
|
21 |
stop="\n"
|
22 |
):
|
@@ -28,6 +29,7 @@ def infer(
|
|
28 |
temperature = float(temperature)
|
29 |
top_p = float(top_p)
|
30 |
top_k = int(top_k)
|
|
|
31 |
stop = stop.split(";")
|
32 |
seed = seed
|
33 |
|
@@ -36,6 +38,7 @@ def infer(
|
|
36 |
assert 0.0 <= temperature <= 10.0
|
37 |
assert 0.0 <= top_p <= 1.0
|
38 |
assert 1 <= top_k <= 1000
|
|
|
39 |
|
40 |
if temperature == 0.0:
|
41 |
temperature = 0.01
|
@@ -48,6 +51,7 @@ def infer(
|
|
48 |
"top_k": top_k,
|
49 |
"temperature": temperature,
|
50 |
"max_tokens": max_new_tokens,
|
|
|
51 |
"stop": stop,
|
52 |
}
|
53 |
print(f"send: {datetime.now()}")
|
@@ -223,6 +227,7 @@ def main():
|
|
223 |
if 'preset' not in st.session_state:
|
224 |
st.session_state.preset = "Sentiment Analysis"
|
225 |
st.session_state.top_k = "40"
|
|
|
226 |
st.session_state.stop = r'\n'
|
227 |
set_preset()
|
228 |
|
@@ -252,6 +257,7 @@ def main():
|
|
252 |
top_p = st.text_input('top_p', st.session_state.top_p)
|
253 |
# num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
|
254 |
num_completions = "1"
|
|
|
255 |
stop = st.text_input('stop, split by;', st.session_state.stop)
|
256 |
# seed = st.text_input('seed', "42")
|
257 |
seed = "42"
|
@@ -275,7 +281,8 @@ def main():
|
|
275 |
generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
|
276 |
report_text = infer(
|
277 |
prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
|
278 |
-
num_completions=num_completions,
|
|
|
279 |
)
|
280 |
generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
|
281 |
|
|
|
17 |
top_p=1.0,
|
18 |
top_k=40,
|
19 |
num_completions=1,
|
20 |
+
repetition_penalty=1.0,
|
21 |
seed=42,
|
22 |
stop="\n"
|
23 |
):
|
|
|
29 |
temperature = float(temperature)
|
30 |
top_p = float(top_p)
|
31 |
top_k = int(top_k)
|
32 |
+
repetition_penalty = float(repetition_penalty)
|
33 |
stop = stop.split(";")
|
34 |
seed = seed
|
35 |
|
|
|
38 |
assert 0.0 <= temperature <= 10.0
|
39 |
assert 0.0 <= top_p <= 1.0
|
40 |
assert 1 <= top_k <= 1000
|
41 |
+
assert 0.9 <= repetition_penalty <= 3.0
|
42 |
|
43 |
if temperature == 0.0:
|
44 |
temperature = 0.01
|
|
|
51 |
"top_k": top_k,
|
52 |
"temperature": temperature,
|
53 |
"max_tokens": max_new_tokens,
|
54 |
+
"repetition_penalty": repetition_penalty,
|
55 |
"stop": stop,
|
56 |
}
|
57 |
print(f"send: {datetime.now()}")
|
|
|
227 |
if 'preset' not in st.session_state:
|
228 |
st.session_state.preset = "Sentiment Analysis"
|
229 |
st.session_state.top_k = "40"
|
230 |
+
st.session_state.repetition_penalty = "1.0"
|
231 |
st.session_state.stop = r'\n'
|
232 |
set_preset()
|
233 |
|
|
|
257 |
top_p = st.text_input('top_p', st.session_state.top_p)
|
258 |
# num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
|
259 |
num_completions = "1"
|
260 |
+
repetition_penalty = st.text_input('repetition_penalty', st.session_state.repetition_penalty)
|
261 |
stop = st.text_input('stop, split by;', st.session_state.stop)
|
262 |
# seed = st.text_input('seed', "42")
|
263 |
seed = "42"
|
|
|
281 |
generated_area.markdown("<b>" + to_md(prompt) + "</b>", unsafe_allow_html=True)
|
282 |
report_text = infer(
|
283 |
prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k,
|
284 |
+
num_completions=num_completions, repetition_penalty=repetition_penalty,
|
285 |
+
seed=seed, stop=literal_eval("'''"+stop+"'''"),
|
286 |
)
|
287 |
generated_area.markdown("<b>" + to_md(prompt) + "</b><mark style='background-color: #cbeacd'>" + to_md(report_text)+"</mark>", unsafe_allow_html=True)
|
288 |
|