Mohinikathro commited on
Commit
0693a86
Β·
verified Β·
1 Parent(s): 0ec1ab7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -7
app.py CHANGED
@@ -66,21 +66,36 @@ def identify_subtopic(question, domain):
66
  def generate_question(prompt, domain, state):
67
  full_prompt = system_prompt + "\n" + prompt
68
  inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
 
69
  outputs = model.generate(
70
  inputs["input_ids"],
71
  max_new_tokens=50,
72
- num_return_sequences=1,
73
- no_repeat_ngram_size=2,
74
- top_k=30,
75
  top_p=0.9,
 
76
  temperature=0.7,
77
- do_sample=True,
78
  pad_token_id=tokenizer.eos_token_id,
 
79
  )
80
- question = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  if not question.endswith("?"):
82
- question = question.split("?")[0] + "?"
83
 
 
84
  subtopic = identify_subtopic(question, domain)
85
  if question not in state["asked_questions"] and (subtopic is None or subtopic not in state["asked_subtopics"]):
86
  state["asked_questions"].add(question)
@@ -88,7 +103,7 @@ def generate_question(prompt, domain, state):
88
  state["asked_subtopics"].add(subtopic)
89
  return question
90
  else:
91
- return generate_question(prompt, domain, state) # Retry
92
 
93
  def match_company(user_input):
94
  user_input_lower = user_input.lower()
 
66
  def generate_question(prompt, domain, state):
67
  full_prompt = system_prompt + "\n" + prompt
68
  inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
69
+
70
  outputs = model.generate(
71
  inputs["input_ids"],
72
  max_new_tokens=50,
73
+ do_sample=True,
 
 
74
  top_p=0.9,
75
+ top_k=30,
76
  temperature=0.7,
77
+ no_repeat_ngram_size=2,
78
  pad_token_id=tokenizer.eos_token_id,
79
+ eos_token_id=tokenizer.eos_token_id
80
  )
81
+
82
+ decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
83
+
84
+ # βœ… Remove everything before the actual question (i.e., the prompt text)
85
+ if decoded.startswith(system_prompt.strip()):
86
+ decoded = decoded[len(system_prompt):].strip()
87
+ if prompt.strip() in decoded:
88
+ decoded = decoded.split(prompt.strip())[-1].strip()
89
+
90
+ # βœ… Extract only the question line
91
+ question_lines = decoded.splitlines()
92
+ question = next((line for line in question_lines if "?" in line), decoded).strip()
93
+
94
+ # βœ… Ensure it ends with a "?"
95
  if not question.endswith("?"):
96
+ question = question.split("?")[0].strip() + "?"
97
 
98
+ # βœ… Check for uniqueness
99
  subtopic = identify_subtopic(question, domain)
100
  if question not in state["asked_questions"] and (subtopic is None or subtopic not in state["asked_subtopics"]):
101
  state["asked_questions"].add(question)
 
103
  state["asked_subtopics"].add(subtopic)
104
  return question
105
  else:
106
+ return generate_question(prompt, domain, state) # Try again
107
 
108
  def match_company(user_input):
109
  user_input_lower = user_input.lower()