Spaces:
Runtime error
Change generation parameters
app.py
CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 from streamlit_chat import message
 
+
 @st.cache(allow_output_mutation=True)
 def get_pipe():
     from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -10,35 +11,38 @@ def get_pipe():
     return model, tokenizer
 
 def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
+    # print("history:", history)
     context = []
     for i, text in enumerate(history):
-        context.append(f"{i % 2}…
+        context.append(f"{i % 2}: {text}</s>")
 
     if len(context) > max_context:
         context = context[-max_context:]
-    context = "".join(context) + f"{bot_id}…
+    context = "".join(context) + f"{bot_id}: "
     inputs = tokenizer(context, return_tensors="pt")
 
     generation_args = dict(
-        max_new_tokens=…
+        max_new_tokens=128,
         min_length=inputs["input_ids"].shape[1] + 5,
+        # no_repeat_ngram_size=4,
         eos_token_id=2,
         do_sample=True,
-        top_p=0.…
-        temperature=…
-        repetition_penalty=1.…
+        top_p=0.95,
+        temperature=1.35,
+        # repetition_penalty=1.0,
         early_stopping=True
     )
 
     outputs = model.generate(**inputs, **generation_args)
-    response = tokenizer.decode(outputs[0])
-    print(…
-    print(response)
-    response = response[len(context):].replace("</s>", "")
-
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    print("Context:", tokenizer.decode(inputs["input_ids"][0]))
+    print("Response:", response)
+    response = response[len(context):].replace("</s>", "").replace("\n", "")
+    response = response.split("<s>")[0]
+    # print("Response:", response)
     return response
 
-st.title("한국어 대화 모델 demo")
+st.title("ajoublue-gpt2-medium 한국어 대화 모델 demo")
 
 with st.spinner("loading model..."):
     model, tokenizer = get_pipe()
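
And a hedged sketch of what the updated generation_args do when passed to transformers' generate. The checkpoint name is an assumption for illustration (the title suggests an ajoublue-gpt2-medium dialogue model, but the exact repo id is not shown in this diff):

from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "heegyu/ajoublue-gpt2-medium"  # hypothetical; substitute the repo id get_pipe() actually loads
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

inputs = tokenizer("0: 안녕하세요</s>1: ", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=128,   # hard cap on generated tokens (new in this commit)
    min_length=inputs["input_ids"].shape[1] + 5,  # force at least 5 new tokens
    eos_token_id=2,       # stop once the </s> token (id 2) is sampled
    do_sample=True,
    top_p=0.95,           # nucleus sampling: smallest token set covering 95% of the mass
    temperature=1.35,     # >1 flattens the distribution, trading coherence for variety
    early_stopping=True,
)
print(tokenizer.decode(outputs[0]))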