heegyu committed
Commit cc7fe5e · Parent: c91acd0

Change generation parameters

Files changed (1)
  app.py (+16, -12)
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 from streamlit_chat import message
 
+
 @st.cache(allow_output_mutation=True)
 def get_pipe():
     from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -10,35 +11,38 @@ def get_pipe():
     return model, tokenizer
 
 def get_response(tokenizer, model, history, max_context: int = 7, bot_id: str = '1'):
+    # print("history:", history)
     context = []
     for i, text in enumerate(history):
-        context.append(f"{i % 2} : {text}</s>")
+        context.append(f"{i % 2}: {text}</s>")
 
     if len(context) > max_context:
         context = context[-max_context:]
-    context = "".join(context) + f"{bot_id} : "
+    context = "".join(context) + f"{bot_id}: "
     inputs = tokenizer(context, return_tensors="pt")
 
     generation_args = dict(
-        max_new_tokens=64,
+        max_new_tokens=128,
         min_length=inputs["input_ids"].shape[1] + 5,
+        # no_repeat_ngram_size=4,
         eos_token_id=2,
         do_sample=True,
-        top_p=0.6,
-        temperature=0.8,
-        repetition_penalty=1.5,
+        top_p=0.95,
+        temperature=1.35,
+        # repetition_penalty=1.0,
         early_stopping=True
     )
 
     outputs = model.generate(**inputs, **generation_args)
-    response = tokenizer.decode(outputs[0])
-    print(context)
-    print(response)
-    response = response[len(context):].replace("</s>", "")
-
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    print("Context:", tokenizer.decode(inputs["input_ids"][0]))
+    print("Response:", response)
+    response = response[len(context):].replace("</s>", "").replace("\n", "")
+    response = response.split("<s>")[0]
+    # print("Response:", response)
     return response
 
-st.title("한국어 대화 모델 demo")
+st.title("ajoublue-gpt2-medium 한국어 대화 모델 demo")
 
 with st.spinner("loading model..."):
     model, tokenizer = get_pipe()
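
For reference, here is a minimal standalone sketch of the new decoding settings, runnable outside Streamlit. The from_pretrained() call is not visible in this hunk, so the model_id below is an assumed placeholder; substitute the Space's actual checkpoint id. The history value is likewise illustrative.

# Standalone sketch of the updated decoding settings; not the app itself.
# NOTE: model_id is an assumed placeholder -- the from_pretrained() call
# is outside this hunk, so substitute the Space's real checkpoint id.
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "heegyu/ajoublue-gpt2-medium"  # assumed placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Rebuild the "speaker: text</s>" context the same way get_response() does,
# with speaker 0 as the user and speaker 1 as the bot.
history = ["안녕하세요!"]  # illustrative single-turn history
context = "".join(f"{i % 2}: {text}</s>" for i, text in enumerate(history)) + "1: "
inputs = tokenizer(context, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=128,                           # raised from 64
    min_length=inputs["input_ids"].shape[1] + 5,  # require at least 5 new tokens
    eos_token_id=2,
    do_sample=True,
    top_p=0.95,                                   # raised from 0.6
    temperature=1.35,                             # raised from 0.8
    early_stopping=True,                          # repetition_penalty=1.5 was dropped
)

# Strip the prompt and special tokens, as the updated app does.
response = tokenizer.decode(outputs[0], skip_special_tokens=False)
response = response[len(context):].replace("</s>", "").replace("\n", "")
response = response.split("<s>")[0]
print(response)

The direction of the change is toward less constrained sampling: higher top_p and temperature with the repetition penalty removed presumably yield more varied replies, while max_new_tokens=128 allows longer turns.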