Mohinikathro commited on
Commit
24e02fc
·
verified ·
1 Parent(s): 0693a86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -24
app.py CHANGED
@@ -26,8 +26,8 @@ model = AutoModelForCausalLM.from_pretrained(
26
  model = torch.compile(model)
27
 
28
  # System prompt for consistent question generation
29
- system_prompt = """
30
- You are conducting a mock technical interview. Your task is to generate clear, concise, and unique interview questions based on the given domain and round. Follow these rules:
31
 
32
  1. Only output one question — do not include explanations, elaborations, or surrounding text.
33
  2. Do not use any labels like "Follow-up Question" or "Question:" in your output. Just the raw question.
@@ -64,38 +64,34 @@ def identify_subtopic(question, domain):
64
  return None
65
 
66
  def generate_question(prompt, domain, state):
67
- full_prompt = system_prompt + "\n" + prompt
68
- inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
69
 
70
  outputs = model.generate(
71
- inputs["input_ids"],
72
- max_new_tokens=50,
73
- do_sample=True,
 
74
  top_p=0.9,
75
  top_k=30,
76
- temperature=0.7,
 
 
77
  no_repeat_ngram_size=2,
78
- pad_token_id=tokenizer.eos_token_id,
79
- eos_token_id=tokenizer.eos_token_id
80
  )
81
 
82
  decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
83
 
84
- # ✅ Remove everything before the actual question (i.e., the prompt text)
85
- if decoded.startswith(system_prompt.strip()):
86
- decoded = decoded[len(system_prompt):].strip()
87
- if prompt.strip() in decoded:
88
- decoded = decoded.split(prompt.strip())[-1].strip()
89
-
90
- # ✅ Extract only the question line
91
- question_lines = decoded.splitlines()
92
- question = next((line for line in question_lines if "?" in line), decoded).strip()
93
 
94
- # ✅ Ensure it ends with a "?"
95
- if not question.endswith("?"):
96
- question = question.split("?")[0].strip() + "?"
 
97
 
98
- # ✅ Check for uniqueness
99
  subtopic = identify_subtopic(question, domain)
100
  if question not in state["asked_questions"] and (subtopic is None or subtopic not in state["asked_subtopics"]):
101
  state["asked_questions"].add(question)
@@ -103,7 +99,7 @@ def generate_question(prompt, domain, state):
103
  state["asked_subtopics"].add(subtopic)
104
  return question
105
  else:
106
- return generate_question(prompt, domain, state) # Try again
107
 
108
  def match_company(user_input):
109
  user_input_lower = user_input.lower()
 
26
  model = torch.compile(model)
27
 
28
  # System prompt for consistent question generation
29
+ system_prompt = f"""
30
+ You are conducting a {round_type.lower()} interview for a position in {domain} at {company}. Generate one concise and unique question: Follow these rules:
31
 
32
  1. Only output one question — do not include explanations, elaborations, or surrounding text.
33
  2. Do not use any labels like "Follow-up Question" or "Question:" in your output. Just the raw question.
 
64
  return None
65
 
66
def generate_question(prompt, domain, state, max_attempts=10):
    """Generate one unique interview question for the given domain.

    Builds a full prompt from the module-level ``system_prompt`` plus the
    caller's ``prompt``, samples a completion from the model, extracts the
    first line containing a ``?``, and deduplicates against ``state``.

    Parameters:
        prompt: Caller-supplied prompt text appended to the system prompt.
        domain: Interview domain, forwarded to ``identify_subtopic``.
        state: Dict holding mutable sets under ``"asked_questions"`` and
            ``"asked_subtopics"``; updated in place when a question is
            accepted.
        max_attempts: Upper bound on resampling when duplicate questions
            are drawn. Replaces the previous unbounded self-recursion,
            which could hit ``RecursionError`` once the question pool
            saturated or the model kept repeating itself.

    Returns:
        A question string ending in ``?``. After ``max_attempts`` failed
        draws, the last candidate is returned even if it is a duplicate
        (best-effort fallback instead of recursing forever).
    """
    full_prompt = f"{system_prompt.strip()}\n{prompt.strip()}"
    inputs = tokenizer(
        full_prompt, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    question = ""
    for _ in range(max_attempts):
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=60,
            temperature=0.7,
            top_p=0.9,
            top_k=30,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2,
        )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        # Step 1: Strip the system prompt and user prompt from the output.
        if full_prompt in decoded:
            decoded = decoded.split(full_prompt)[-1].strip()

        # Step 2: Keep only the first line containing '?', truncated at it.
        lines = decoded.splitlines()
        question = next((line for line in lines if "?" in line), decoded)
        question = question.split("?")[0].strip() + "?"

        # Step 3: Accept only questions (and subtopics) not seen before.
        subtopic = identify_subtopic(question, domain)
        if question not in state["asked_questions"] and (
            subtopic is None or subtopic not in state["asked_subtopics"]
        ):
            state["asked_questions"].add(question)
            if subtopic is not None:
                state["asked_subtopics"].add(subtopic)
            return question

    # Retry budget exhausted: return the last candidate rather than
    # recursing indefinitely on a saturated question pool.
    return question
103
 
104
  def match_company(user_input):
105
  user_input_lower = user_input.lower()