Aditi committed on
Commit
66e60d9
·
1 Parent(s): b66e4e3

resolved model issue

Browse files
Files changed (1) hide show
  1. short_answer_generator.py +100 -0
short_answer_generator.py CHANGED
@@ -95,3 +95,103 @@ def main():
95
 
96
  if __name__ == "__main__":
97
  main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  if __name__ == "__main__":
97
  main()
98
+ import torch
99
+ import random
100
+ from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
101
+
102
class QuestionGenerator:
    """Generate short-answer questions from a passage of text.

    Candidate questions are built by joining a random template with a random
    five-word snippet of the context. Each candidate is answered against the
    same context by a question-answering pipeline, and only answers that pass
    the length, score and uniqueness checks are kept.
    """

    def __init__(self, model_name='distilbert-base-uncased-distilled-squad'):
        """
        Initialize question generation system using a stable QA model.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the tokenizer.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Create QA pipeline (device 0 when CUDA is available, -1 otherwise)
        self.qa_pipeline = pipeline(
            'question-answering',
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if self.device == 'cuda' else -1
        )

        # Sample templates to simulate natural QA generation
        self.question_templates = [
            "What is the main idea of",
            "Who is responsible for",
            "When did this occur",
            "Where does this take place",
            "Why is this important",
            "How does this work",
            "What are the key features of",
            "Explain the significance of",
            "What is the purpose of",
            "Describe the process of"
        ]

    def generate_questions(self, context, num_questions=3, difficulty='medium'):
        """
        Generate short answer questions based on provided context.

        Args:
            context: Text the questions are drawn from and answered against.
            num_questions: Maximum number of questions to return.
            difficulty: Accepted for interface compatibility; currently unused.

        Returns:
            A list of dicts with keys ``'question'``, ``'answer'`` and
            ``'confidence'``. The list may be shorter than ``num_questions``
            (or empty) when too few candidates pass validation within
            ``num_questions * 10`` attempts, or when ``context`` contains no
            words.
        """
        # Split once: the word list does not change between attempts.
        words = context.split() if context else []
        if not words:
            # Nothing to sample a snippet from; skip the retry loop entirely
            # instead of failing (and printing) on every attempt.
            return []

        generated_questions = []
        seen_answers = set()  # lowercased answers already accepted
        max_attempts = num_questions * 10
        last_start = max(0, len(words) - 5)
        attempts = 0

        while len(generated_questions) < num_questions and attempts < max_attempts:
            # Every iteration counts as one attempt, whether it succeeds,
            # is rejected by validation, or raises.
            attempts += 1
            try:
                template = random.choice(self.question_templates)
                start_index = random.randint(0, last_start)
                snippet = ' '.join(words[start_index:start_index + 5])
                full_question = f"{template} {snippet}?"

                result = self.qa_pipeline(question=full_question, context=context)
                answer = result['answer']

                # Validate and deduplicate
                if (
                    answer
                    and len(answer) > 3
                    and result['score'] > 0.5
                    and answer.lower() not in seen_answers
                ):
                    seen_answers.add(answer.lower())
                    generated_questions.append({
                        'question': full_question,
                        'answer': answer,
                        'confidence': result['score']
                    })

            except Exception as e:
                # Best effort: report the failure and try another candidate.
                print(f"Question generation error: {e}")

        return generated_questions

    def display_questions(self, questions):
        """Print each question and its expected answer keyword, numbered from 1.

        Args:
            questions: Iterable of dicts as returned by ``generate_questions``.
        """
        print("\n--- Generated Questions ---")
        for idx, q in enumerate(questions, 1):
            print(f"Q{idx}: {q['question']}")
            print(f"Expected keyword: {q['answer']} \n")
176
+
177
+ # Run this if testing standalone
178
if __name__ == "__main__":
    # Interactive driver: read a context and a question count from stdin,
    # generate questions, and print them.
    print("\n>> Enter the context for question generation: ")
    context = input().strip()
    if not context:
        # An empty context gives the generator nothing to sample from, so
        # stop here instead of asking for a count and loading the model.
        raise SystemExit("No context provided; nothing to generate questions from.")

    # Keep prompting until the user enters an integer in the range 1-10.
    while True:
        try:
            num_q = int(input("\n>> How many questions do you want? (1-10): "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if 1 <= num_q <= 10:
            break
        print("Please enter a number between 1 and 10.")

    generator = QuestionGenerator()
    questions = generator.generate_questions(context, num_questions=num_q)

    if questions:
        generator.display_questions(questions)
    else:
        print("❌ Could not generate any questions.")