Spaces (Sleeping)
Aditi committed · Commit b57b4eb
Parent(s): 2b57a7b

Removed old scripts and added quiz logic

Files changed:
- mcq_generator.ipynb  +0 -423
- quiz_logic.py  +78 -0
- short_answer_generator.ipynb  +0 -603
- true_false_generator.ipynb  +0 -250
- true_false_generator.py  +0 -1
- truefalse_quiz.py  +98 -0
mcq_generator.ipynb
DELETED
@@ -1,423 +0,0 @@
Notebook metadata: nbformat 4, Colab notebook (provenance: [], gpuType: "T4", authorship_tag: "ABX9TyNlLgN36uc2PRyXWiLUUS03", include_colab_link: true), kernel: Python 3, accelerator: GPU.

Markdown cell:
<a href="https://colab.research.google.com/github/DishaKushwah/custom-quiz-generator/blob/main/mcq_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code cell:
## MCQS
import torch
from transformers import (
    AutoTokenizer, AutoModel
)
from transformers.pipelines import pipeline
from transformers.models.t5 import T5ForConditionalGeneration, T5Tokenizer

from sentence_transformers import SentenceTransformer
import spacy
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import re
import random
from typing import List, Dict, Tuple
import nltk
from nltk.corpus import wordnet
import string

class MultipleChoiceQuestionGenerator:
    def __init__(self):
        """Initialize the MCQ generator with advanced models."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load T5 model for question generation
        self.qg_model_name = "valhalla/t5-base-qg-hl"
        self.qg_tokenizer = T5Tokenizer.from_pretrained(self.qg_model_name)
        self.qg_model = T5ForConditionalGeneration.from_pretrained(self.qg_model_name).to(self.device)

        # Load question-answering pipeline for answer validation
        self.qa_pipeline = pipeline(
            "question-answering",
            model="deepset/roberta-large-squad2",
            tokenizer="deepset/roberta-large-squad2",
            device=0 if torch.cuda.is_available() else -1
        )

        # Load sentence transformer for semantic similarity (distractor generation)
        self.sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load spaCy for NLP processing
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Please install spaCy English model: python -m spacy download en_core_web_sm")
            self.nlp = None

        # Load fill-mask pipeline for generating distractors
        self.fill_mask = pipeline("fill-mask", model="roberta-large", tokenizer="roberta-large", device=0 if torch.cuda.is_available() else -1)

        # Download NLTK data
        try:
            nltk.download('wordnet', quiet=True)
            nltk.download('omw-1.4', quiet=True)
        except:
            pass

    def extract_key_information(self, text: str) -> Dict:
        """Extract key information from text for question generation."""
        if not self.nlp:
            return {"entities": [], "noun_chunks": [], "sentences": []}

        doc = self.nlp(text)
        # Extract named entities
        entities = []
        for ent in doc.ents:
            if ent.label_ in ['PERSON', 'ORG', 'GPE', 'DATE', 'EVENT', 'WORK_OF_ART', 'CARDINAL', 'ORDINAL']:
                entities.append({'text': ent.text, 'label': ent.label_, 'start': ent.start_char, 'end': ent.end_char})

        # Extract noun chunks
        noun_chunks = [chunk.text for chunk in doc.noun_chunks if len(chunk.text.split()) <= 4]

        # Extract sentences
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.split()) > 5]

        return {"entities": entities, "noun_chunks": noun_chunks, "sentences": sentences}

    def generate_question_from_context(self, context: str, answer_text: str) -> str:
        """Generate a question given context and answer."""
        # Highlight the answer in the context for T5
        highlighted_context = context.replace(answer_text, f"<hl>{answer_text}<hl>")
        input_text = f"generate question: {highlighted_context}"
        inputs = self.qg_tokenizer.encode_plus(input_text, max_length=512, truncation=True, padding=True, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.qg_model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=64, num_beams=4, temperature=0.8, do_sample=True, early_stopping=True)

        question = self.qg_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return question

    def generate_distractors_semantic(self, correct_answer: str, context: str, num_distractors: int = 3) -> List[str]:
        """Generate distractors using semantic similarity and context understanding."""
        distractors = []

        # Method 1: Use fill-mask to generate contextually similar options
        try:
            # Replace answer with mask in context
            masked_context = context.replace(correct_answer, "<mask>")
            if "<mask>" in masked_context:
                predictions = self.fill_mask(masked_context, top_k=20)
                for pred in predictions:
                    candidate = pred['token_str'].strip()
                    if (candidate != correct_answer and
                        candidate.lower() != correct_answer.lower() and
                        len(candidate) > 1 and
                        candidate not in distractors):
                        distractors.append(candidate)
                        if len(distractors) >= num_distractors:
                            break
        except:
            pass

        # Method 2: Extract similar entities from context
        if self.nlp and len(distractors) < num_distractors:
            doc = self.nlp(context)
            answer_doc = self.nlp(correct_answer)

            # Get answer entity type
            answer_label = None
            for ent in answer_doc.ents:
                answer_label = ent.label_
                break

            # Find similar entities
            for ent in doc.ents:
                if (ent.label_ == answer_label and ent.text != correct_answer and ent.text not in distractors):
                    distractors.append(ent.text)
                    if len(distractors) >= num_distractors:
                        break

        # Method 3: Generate using WordNet synonyms and related words
        if len(distractors) < num_distractors:
            try:
                words = correct_answer.split()
                for word in words:
                    synsets = wordnet.synsets(word)
                    for synset in synsets[:3]:
                        for lemma in synset.lemmas()[:2]:
                            candidate = lemma.name().replace('_', ' ')
                            if (candidate != correct_answer and
                                candidate.lower() != correct_answer.lower() and
                                candidate not in distractors):
                                distractors.append(candidate)
                            if len(distractors) >= num_distractors:
                                break
                        if len(distractors) >= num_distractors:
                            break
                    if len(distractors) >= num_distractors:
                        break
            except:
                pass

        # Method 4: Generate plausible distractors based on answer type
        if len(distractors) < num_distractors:
            distractors.extend(self.generate_type_based_distractors(correct_answer, context))

        # Remove duplicates and return
        unique_distractors = []
        seen = set()
        for d in distractors:
            if d.lower() not in seen and d.lower() != correct_answer.lower():
                seen.add(d.lower())
                unique_distractors.append(d)
        return unique_distractors[:num_distractors]

    def validate_mcq_quality(self, question: str, correct_answer: str, distractors: List[str], context: str) -> Dict:
        """Validate the quality of generated MCQ."""
        # Check if the question can be answered correctly
        try:
            qa_result = self.qa_pipeline(question=question, context=context)
            predicted_answer = qa_result['answer']
            confidence = qa_result['score']

            # Check if predicted answer matches or is similar to correct answer
            similarity_threshold = 0.7
            correct_embedding = self.sentence_model.encode([correct_answer])
            predicted_embedding = self.sentence_model.encode([predicted_answer])
            similarity = cosine_similarity(correct_embedding, predicted_embedding)[0][0]
            is_answerable = similarity > similarity_threshold or correct_answer.lower() in predicted_answer.lower()

        except:
            is_answerable = False
            confidence = 0.0
            similarity = 0.0

        # Check distractor quality
        if len(distractors) > 0:
            distractor_embeddings = self.sentence_model.encode(distractors)
            correct_embedding = self.sentence_model.encode([correct_answer])

            # Calculate similarity between distractors and correct answer
            similarities = cosine_similarity(correct_embedding, distractor_embeddings)[0]
            avg_distractor_similarity = np.mean(similarities)

            # Good distractors should be somewhat similar but not too similar
            distractor_quality = "good" if 0.3 < avg_distractor_similarity < 0.8 else "poor"
        else:
            distractor_quality = "poor"
            avg_distractor_similarity = 0.0

        return {"is_answerable": is_answerable, "confidence": confidence, "answer_similarity": similarity, "distractor_quality": distractor_quality, "avg_distractor_similarity": avg_distractor_similarity}

    def generate_mcq(self, context: str, num_questions: int = 5) -> List[Dict]:
        """Generate multiple choice questions from context."""
        mcqs = []

        # Extract key information
        key_info = self.extract_key_information(context)

        # Generate questions from entities
        for entity in key_info["entities"][:num_questions]:
            correct_answer = entity["text"]

            # Generate question
            question = self.generate_question_from_context(context, correct_answer)

            # Generate distractors
            distractors = self.generate_distractors_semantic(correct_answer, context, 3)

            # Skip if not enough distractors
            if len(distractors) < 2:
                continue

            # Validate quality
            quality = self.validate_mcq_quality(question, correct_answer, distractors, context)

            # Create options and shuffle
            options = [correct_answer] + distractors[:3]
            random.shuffle(options)
            correct_option = chr(65 + options.index(correct_answer))  # A, B, C, D

            mcq = {
                "question": question,
                "options": {"A": options[0], "B": options[1], "C": options[2] if len(options) > 2 else "None of the above", "D": options[3] if len(options) > 3 else "All of the above"},
                "correct_answer": correct_option,
                "correct_text": correct_answer,
                "entity_type": entity["label"],
                "quality_score": quality["confidence"],
                "is_answerable": quality["is_answerable"]
            }

            # Only include high-quality MCQs
            if quality["is_answerable"] and quality["confidence"] > 0.3:
                mcqs.append(mcq)

        # Generate additional questions from noun chunks if needed
        if len(mcqs) < num_questions:
            for chunk in key_info["noun_chunks"][:num_questions - len(mcqs)]:
                question = self.generate_question_from_context(context, chunk)
                distractors = self.generate_distractors_semantic(chunk, context, 3)

                if len(distractors) >= 2:
                    quality = self.validate_mcq_quality(question, chunk, distractors, context)

                    if quality["is_answerable"] and quality["confidence"] > 0.2:
                        options = [chunk] + distractors[:3]
                        random.shuffle(options)
                        correct_option = chr(65 + options.index(chunk))

                        mcq = {
                            "question": question,
                            "options": {"A": options[0], "B": options[1], "C": options[2] if len(options) > 2 else "None of the above", "D": options[3] if len(options) > 3 else "All of the above"},
                            "correct_answer": correct_option,
                            "correct_text": chunk,
                            "entity_type": "NOUN_CHUNK",
                            "quality_score": quality["confidence"],
                            "is_answerable": quality["is_answerable"]
                        }
                        mcqs.append(mcq)

        # Sort by quality score and return
        mcqs.sort(key=lambda x: x["quality_score"], reverse=True)
        return mcqs[:num_questions]

def main():
    """Main function to demonstrate the MCQ generator."""
    generator = MultipleChoiceQuestionGenerator()
    print("Multiple Choice Question Generator")

    # Get user input
    user_context = input("Enter your context: ").strip()
    try:
        num_questions = int(input("Number of MCQs to generate (default 5): ") or "5")
    except ValueError:
        num_questions = 5
    print(f"\nGenerating {num_questions} multiple choice questions...")

    # Generate MCQs
    mcqs = generator.generate_mcq(user_context, num_questions)
    # Display results
    if mcqs:
        for i, mcq in enumerate(mcqs, 1):
            print(f"\nQuestion {i}: ")
            print(f"Q: {mcq['question']}")
            print()
            for option, text in mcq['options'].items():
                print(f"{option}) {text}")
            print(f"\nCorrect Answer: {mcq['correct_answer']}) {mcq['correct_text']}")
    else:
        print("No high-quality MCQs could be generated from the provided context.")
        print("Try providing a longer, more detailed context with specific facts and entities.")

    print("\nGeneration complete!")

if __name__ == "__main__":
    main()

Cell output (stderr):
Device set to use cpu
Device set to use cpu

Cell output (stdout):
Multiple Choice Question Generator
Enter your context (or press Enter to use sample): India has numerous national parks dedicated to preserving wildlife and biodiversity. Some of the most famous include Jim Corbett National Park in Uttarakhand, known for tigers; Kaziranga National Park in Assam, home to the one-horned rhinoceros; and Sundarbans in West Bengal, famous for mangrove forests and Royal Bengal Tigers. These parks also support eco-tourism and help protect endangered species and fragile ecosystems.
Number of MCQs to generate (default 5): 4

Generating 4 multiple choice questions...

Question 1:
Q: What country has numerous national parks dedicated to preserving wildlife and biodiversity?

A) Asia
B) India
C) Indian
D) Pakistan

Correct Answer: B) India

Question 2:
Q: What national park in Uttarakhand is known for tigers?

A) Sanctuary
B) Leh
C) Jim Corbett National Park
D) Gir

Correct Answer: C) Jim Corbett National Park

Question 3:
Q: Where is Jim Corbett National Park located?

A) Uttarakhand
B) Maharashtra
C) Kerala
D) Bengal

Correct Answer: A) Uttarakhand

Question 4:
Q: How many horned rhinoceros live in Kaziranga National Park?

A) big
B) one
C) great
D) two

Correct Answer: B) one

Generation complete!

A second, empty code cell followed (no source, no outputs).
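For reference, the question-generation step in the removed notebook reduces to a highlight-and-prompt pattern for the valhalla/t5-base-qg-hl checkpoint. A minimal standalone sketch of that pattern (using the generic AutoTokenizer/AutoModelForSeq2SeqLM APIs rather than the notebook's class; not part of this commit):

# Sketch of the <hl>-highlighted question-generation input used by the deleted notebook.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("valhalla/t5-base-qg-hl")
model = AutoModelForSeq2SeqLM.from_pretrained("valhalla/t5-base-qg-hl")

context = "Kaziranga National Park in Assam is home to the one-horned rhinoceros."
answer = "Kaziranga National Park"

# Wrap the answer span in <hl> markers and prefix with "generate question:",
# mirroring generate_question_from_context above.
highlighted = context.replace(answer, f"<hl>{answer}<hl>")
inputs = tokenizer(f"generate question: {highlighted}", return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(**inputs, max_length=64, num_beams=4)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))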
quiz_logic.py
ADDED
@@ -0,0 +1,78 @@
# quiz_logic.py
import random
import nltk
from transformers import pipeline
from nltk.tokenize import sent_tokenize

# Download required tokenizer
nltk.download('punkt', quiet=True)

# Load NLI model
nli = pipeline("text-classification", model="facebook/bart-large-mnli")

def validate_inputs(context, num_questions, difficulty):
    if not context.strip():
        return False, "Context cannot be empty."
    sentences = sent_tokenize(context)
    if len(sentences) < num_questions:
        return False, f"Context has only {len(sentences)} sentences, but {num_questions} questions requested."
    if difficulty not in ["easy", "medium", "hard"]:
        return False, "Difficulty must be 'easy', 'medium', or 'hard'."
    return True, sentences

def apply_noise(sentence: str, level: str) -> str:
    if level == "easy":
        return sentence
    elif level == "medium":
        if "Sun" in sentence:
            return sentence.replace("Sun", "Moon")
        return sentence.replace("is", "is not") if "is" in sentence else sentence
    elif level == "hard":
        if "eight" in sentence:
            return sentence.replace("eight", "ten")
        return sentence.replace("planets", "stars") if "planets" in sentence else sentence
    return sentence

def generate_statements(context, n, difficulty, sentences):
    random.seed(42)
    selected = random.sample(sentences, min(n * 2, len(sentences)))
    final = []
    for s in selected:
        clean = s.strip()
        modified = apply_noise(clean, difficulty)
        label = "ENTAILMENT" if clean == modified else "CONTRADICTION"
        final.append({"statement": modified, "actual_label": label})
        if len(final) >= n:
            break
    return final

def score_answers(context, answers):
    score = 0
    results = []
    for answer in answers:
        statement = answer.get('statement')
        user_answer = answer.get('user_answer', '').strip().lower()
        if user_answer not in ['true', 'false']:
            results.append({
                "statement": statement,
                "result": "Invalid answer. Please use 'true' or 'false'."
            })
            continue
        input_text = f"{context} [SEP] {statement}"
        result = nli(input_text)[0]
        if result["label"] == "neutral":
            results.append({
                "statement": statement,
                "result": "Skipped due to ambiguous statement."
            })
            continue
        model_label = "ENTAILMENT" if result["label"] == "entailment" else "CONTRADICTION"
        is_correct = (model_label == "ENTAILMENT" and user_answer == "true") or \
                     (model_label == "CONTRADICTION" and user_answer == "false")
        results.append({
            "statement": statement,
            "result": "Correct" if is_correct else f"Incorrect (Correct answer: {'True' if model_label == 'ENTAILMENT' else 'False'})"
        })
        if is_correct:
            score += 1
    return score, results
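The three helpers are meant to be called in sequence by the hosting app: validate the context, build noisy true/false statements, then score the user's answers with the NLI model. A minimal driver showing that flow (hypothetical; the actual caller is not part of this diff, and it assumes quiz_logic.py is importable from the app code):

# Hypothetical usage sketch for the added quiz logic.
from quiz_logic import validate_inputs, generate_statements, score_answers

context = "The Sun is a star. The solar system has eight planets."
ok, sentences_or_error = validate_inputs(context, num_questions=2, difficulty="medium")
if not ok:
    raise ValueError(sentences_or_error)

# Build statements (some perturbed, labelled CONTRADICTION) and collect user answers.
statements = generate_statements(context, 2, "medium", sentences_or_error)
answers = [{"statement": s["statement"], "user_answer": "false"} for s in statements]

score, results = score_answers(context, answers)
print(score, results)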
short_answer_generator.ipynb
DELETED
@@ -1,603 +0,0 @@
Notebook metadata: nbformat 4, Colab notebook (provenance: [], authorship_tag: "ABX9TyP9geFhx2LxpKcLjq1rwpK/", include_colab_link: true), kernel: Python 3.

Markdown cell:
<a href="https://colab.research.google.com/github/DishaKushwah/custom-quiz-generator/blob/main/short_answer_generator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code cell (execution_count 2):
## SHORT ANSWERS
%pip install --upgrade transformers
import torch
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForQuestionAnswering, pipeline, T5ForConditionalGeneration, T5Tokenizer)
import spacy
import numpy as np
from sentence_transformers import SentenceTransformer
import re
import nltk
from nltk.tokenize import sent_tokenize
from typing import List, Dict, Tuple, Optional
import random
from dataclasses import dataclass
import json

Cell output (stdout):
Collecting transformers
  Downloading transformers-4.53.2-py3-none-any.whl.metadata (40 kB)
Downloading transformers-4.53.2-py3-none-any.whl (10.8 MB)
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.53.1
    Uninstalling transformers-4.53.1:
      Successfully uninstalled transformers-4.53.1
Successfully installed transformers-4.53.2

Code cell (execution_count 5):
@dataclass
class ShortAnswerQuestion:
    question: str
    answer: str
    context_sentence: str
    question_type: str
    difficulty: str
    confidence: float
    keywords: List[str]
    expected_length: str

class AdvancedShortAnswerGenerator:
    def __init__(self):
        """Initialize with state-of-the-art models for question generation."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load the best question generation model - T5-large fine-tuned for QG
        self.qg_model_name = "valhalla/t5-base-qg-hl"

        # Use AutoTokenizer and AutoModelForSeq2SeqLM for broader compatibility
        self.qg_tokenizer = AutoTokenizer.from_pretrained(self.qg_model_name)
        self.qg_model = AutoModelForSeq2SeqLM.from_pretrained(self.qg_model_name).to(self.device)

        # Load FLAN-T5 for better question generation
        self.flan_model_name = "google/flan-t5-base"
        self.flan_tokenizer = AutoTokenizer.from_pretrained(self.flan_model_name)
        self.flan_model = AutoModelForSeq2SeqLM.from_pretrained(self.flan_model_name).to(self.device)

        # Load DeBERta for high-quality answer extraction
        self.qa_model_name = "deepset/roberta-base-squad2"
        self.qa_tokenizer = AutoTokenizer.from_pretrained(self.qa_model_name)
        self.qa_model = AutoModelForQuestionAnswering.from_pretrained(self.qa_model_name).to(self.device)

        # Load sentence transformer for semantic analysis
        self.sentence_model = SentenceTransformer('all-mpnet-base-v2')

        # Load spaCy for advanced NLP
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            print("Please install spaCy English model: python -m spacy download en_core_web_sm")
            self.nlp = None
        try:
            nltk.download('punkt', quiet=True)
            nltk.download('stopwords', quiet=True)
            nltk.download('averaged_perceptron_tagger', quiet=True)
            nltk.download('punkt_tab', quiet=True)  # Added download for punkt_tab
        except:
            pass
        # Question type templates
        self.question_templates = {
            'factual': ["What is {}?", "What does {} mean?", "What are the characteristics of {}?", "Define {}.", "Explain {}."],
            'analytical': ["How does {} work?", "Why is {} important?", "What is the significance of {}?", "How does {} relate to {}?", "What are the implications of {}?"],
            'comparative': ["Compare {} and {}.", "What are the differences between {} and {}?", "How does {} differ from {}?", "What are the similarities between {} and {}?"],
            'causal': ["What caused {}?", "What are the effects of {}?", "How did {} lead to {}?", "What resulted from {}?"],
            'procedural': ["How do you {}?", "What are the steps to {}?", "Describe the process of {}.", "What is the procedure for {}?"]}

    def extract_key_concepts(self, text: str) -> Dict:
        """Extract key concepts and entities from text."""
        if not self.nlp:
            return {"entities": [], "concepts": [], "sentences": []}
        doc = self.nlp(text)

        entities = []
        for ent in doc.ents:
            # Include more entity types for broader question generation
            if ent.label_ in ['PERSON', 'ORG', 'GPE', 'EVENT', 'WORK_OF_ART', 'LAW', 'LANGUAGE', 'DATE', 'CARDINAL', 'ORDINAL', 'NORP', 'LOC', 'PRODUCT']:
                entities.append({'text': ent.text, 'label': ent.label_, 'start': ent.start_char, 'end': ent.end_char})

        # Extract key concepts (noun phrases)
        concepts = []
        for chunk in doc.noun_chunks:
            # Adjust length for slightly longer concepts
            if 2 <= len(chunk.text.split()) <= 5:
                concepts.append({'text': chunk.text, 'pos': chunk.root.pos_, 'start': chunk.start_char, 'end': chunk.end_char})

        # Extract sentences
        sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.split()) >= 10]  # Increased minimum sentence length
        return {"entities": entities, "concepts": concepts, "sentences": sentences}

    def generate_question_with_t5(self, context: str, answer: str, difficulty: str = "medium", question_type: str = "factual") -> str:
        """Generate question using T5 model with prepend approach, considering difficulty."""
        # Incorporate difficulty into the prompt
        prompt_prefix = f"generate {difficulty} {question_type} question:"
        input_text = f"{prompt_prefix} context: {context} \\n {answer}"
        inputs = self.qg_tokenizer.encode_plus(input_text, max_length=512, truncation=True, padding=True, return_tensors="pt").to(self.device)

        # Adjust generation parameters based on difficulty (simple heuristic)
        max_length = 100
        num_beams = 5
        temperature = 0.7
        if difficulty == "easy":
            max_length = 80
            temperature = 0.6
        elif difficulty == "hard":
            max_length = 120
            temperature = 0.9
            num_beams = 8

        with torch.no_grad():
            outputs = self.qg_model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=max_length, num_beams=num_beams, temperature=temperature, do_sample=True, early_stopping=True, no_repeat_ngram_size=2)
        question = self.qg_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return question.strip()

    def generate_question_with_flan(self, context: str, answer: str, difficulty: str = "medium", question_type: str = "factual") -> str:
        """Generate question using FLAN-T5 model, considering difficulty."""
        # Incorporate difficulty into the prompt
        prompt = f"""Given the following context, generate a concise {difficulty}-level short answer question where the answer is '{answer}':
Context: {context}
Question:"""
        inputs = self.flan_tokenizer(prompt, max_length=512, truncation=True, padding=True, return_tensors="pt").to(self.device)

        # Adjust generation parameters based on difficulty (simple heuristic)
        max_length = 150
        num_beams = 4
        temperature = 0.8
        if difficulty == "easy":
            max_length = 100
            temperature = 0.6
        elif difficulty == "hard":
            max_length = 200
            temperature = 0.9
            num_beams = 6
        with torch.no_grad():
            outputs = self.flan_model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=max_length, num_beams=num_beams, temperature=temperature, do_sample=True, early_stopping=True)

        question = self.flan_tokenizer.decode(outputs[0], skip_special_tokens=True)
        return question.strip()

    def classify_question_difficulty(self, question: str, answer: str, context: str) -> str:
        """Classify question difficulty based on complexity and context."""
        if not self.nlp:
            return "medium"  # Default to medium if spaCy is not loaded

        question_lower = question.lower()
        question_words = question_lower.split()
        answer_words = answer.lower().split()

        # Keyword indicators
        easy_keywords = ['what', 'who', 'when', 'where', 'name', 'list', 'define']
        medium_keywords = ['how', 'why', 'explain', 'describe', 'role', 'purpose']
        hard_keywords = ['analyze', 'evaluate', 'synthesize', 'impact', 'implication', 'relationship']
        easy_score = sum(1 for word in easy_keywords if word in question_words)
        medium_score = sum(1 for word in medium_keywords if word in question_words)
        hard_score = sum(1 for word in hard_keywords if word in question_words)

        # Answer length
        answer_length_score = 0
        if len(answer_words) > 15:
            answer_length_score = 3
        elif len(answer_words) > 8:
            answer_length_score = 2
        elif len(answer_words) > 3:
            answer_length_score = 1

        # Context complexity (simple measure: average sentence length)
        sentences = sent_tokenize(context)
        avg_sentence_length = np.mean([len(s.split()) for s in sentences]) if sentences else 0

        context_complexity_score = 0
        if avg_sentence_length > 25:
            context_complexity_score = 2
        elif avg_sentence_length > 18:
            context_complexity_score = 1
        if self.nlp:
            doc_question = self.nlp(question)
            doc_answer = self.nlp(answer)
            # 1. Part-of-speech tagging (weights based on complexity)
            # Corrected way to get POS counts
            pos_counts_question = {}
            for token in doc_question:
                pos_counts_question[token.pos_] = pos_counts_question.get(token.pos_, 0) + 1
            pos_counts_answer = {}
            for token in doc_answer:
                pos_counts_answer[token.pos_] = pos_counts_answer.get(token.pos_, 0) + 1
            pos_score = (
                pos_counts_question.get(spacy.parts_of_speech.ADJ, 0) * 0.6 +  # Further increased weight for adjectives
                pos_counts_question.get(spacy.parts_of_speech.ADV, 0) * 0.7 +  # Further increased weight for adverbs
                pos_counts_question.get(spacy.parts_of_speech.VERB, 0) * 0.5 +  # Further increased weight for verbs
                pos_counts_answer.get(spacy.parts_of_speech.ADJ, 0) * 0.5 +
                pos_counts_answer.get(spacy.parts_of_speech.NOUN, 0) * 0.4)  # Further increased weight for nouns in answer

            # 2. Dependency parsing complexity (simple measure: average dependency depth) - higher depth means more complex syntax
            dep_depths_question = [len(list(token.ancestors)) for token in doc_question]
            avg_dep_depth_question = np.mean(dep_depths_question) if dep_depths_question else 0
            dep_score = avg_dep_depth_question * 1.2  # Significantly increased weight

            # 3. Named entity recognition - more entities can indicate more specific/complex questions
            num_entities_question = len(doc_question.ents)
            num_entities_answer = len(doc_answer.ents)
            entity_score = (num_entities_question * 1.2 + num_entities_answer * 1.5)  # Significantly increased weight for entities

            # 4. Lexical diversity (Type-Token Ratio) - lower TTR might indicate simpler language, higher TTR more complex
            question_tokens = [token.text.lower() for token in doc_question if token.is_alpha]
            answer_tokens = [token.text.lower() for token in doc_answer if token.is_alpha]
            question_ttr = len(set(question_tokens)) / len(question_tokens) if question_tokens else 0
            answer_ttr = len(set(answer_tokens)) / len(answer_tokens) if answer_tokens else 0

            # Inverse TTR for scoring (lower TTR = higher score for easy, higher TTR = higher score for hard)
            ttr_score = (question_ttr * 3.0 + answer_ttr * 2.5)  # Significantly increased weight
            # Combine linguistic features into a single score
            linguistic_score = pos_score + dep_score + entity_score + ttr_score
        else:
            linguistic_score = 0

        # Combine all scores with adjusted weights
        total_score = (hard_score * 7 + medium_score * 4 + easy_score * 1.5 + answer_length_score * 3.0 + context_complexity_score * 3.0 + linguistic_score * 2.5)

        # Refined thresholds based on adjusted scoring
        # These thresholds will likely need tuning based on testing
        if total_score > 28:  # Adjusted thresholds slightly down
            return "hard"
        elif total_score > 14:  # Adjusted thresholds slightly down
            return "medium"
        else:
            return "easy"

    def determine_question_type(self, question: str) -> str:
        """Determine the type of question based on its content."""
        question_lower = question.lower()

        if any(word in question_lower for word in ['what is', 'what are', 'define', 'who is', 'who are', 'when is', 'when did', 'where is', 'where are']):
            return "factual"
        elif any(word in question_lower for word in ['how does', 'how to', 'why is', 'why do', 'explain', 'describe']):
            return "analytical"
        elif any(word in question_lower for word in ['compare', 'contrast', 'differ', 'similarities', 'differences']):
            return "comparative"
        elif any(word in question_lower for word in ['cause', 'effect', 'result', 'lead to', 'consequence']):
            return "causal"
        elif any(word in question_lower for word in ['steps', 'process', 'procedure', 'how to']):
            return "procedural"
        else:
            return "factual"

    def extract_keywords(self, text: str) -> List[str]:
        """Extract keywords from text using NLP."""
        if not self.nlp:
            return []

        doc = self.nlp(text)
        keywords = []
        stopwords = set(nltk.corpus.stopwords.words('english'))

        for token in doc:
            if (token.pos_ in ['NOUN', 'PROPN', 'ADJ', 'VERB'] and token.text.lower() not in stopwords and not token.is_punct and len(token.text) > 2):
                keywords.append(token.text)

        # Prioritize multi-word concepts if they exist
        multi_word_keywords = [chunk.text for chunk in doc.noun_chunks if len(chunk.text.split()) > 1 and len(chunk.text.split()) <= 3]
        keywords = multi_word_keywords + keywords
        return list(set(keywords))

    def validate_question_answer_pair(self, question: str, expected_answer: str, context: str) -> Dict:
        """Validate if the question can be answered correctly from the context using the QA model."""
        try:
            # Use QA model and tokenizer explicitly
            inputs = self.qa_tokenizer(question, context, add_special_tokens=True, return_tensors="pt", truncation=True, max_length=512)
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.qa_model(input_ids=input_ids, attention_mask=inputs["attention_mask"])

            answer_start_scores = outputs.start_logits
            answer_end_scores = outputs.end_logits

            # Get the most likely answer span
            answer_start = torch.argmax(answer_start_scores)
            answer_end = torch.argmax(answer_end_scores) + 1

            # Convert tokens to predicted answer string
            predicted_answer = self.qa_tokenizer.decode(input_ids[0, answer_start:answer_end], skip_special_tokens=True)

            # Calculate a confidence score (using max of start and end logits)
            confidence = (torch.max(torch.softmax(answer_start_scores, dim=-1)) + torch.max(torch.softmax(answer_end_scores, dim=-1))) / 2.0

            # Calculate semantic similarity between expected and predicted answers
            # Handle potential errors if encoding fails
            try:
                expected_embedding = self.sentence_model.encode([expected_answer])
                predicted_embedding = self.sentence_model.encode([predicted_answer])
                similarity = np.dot(expected_embedding[0], predicted_embedding[0]) / (np.linalg.norm(expected_embedding[0]) * np.linalg.norm(predicted_embedding[0]))
            except Exception as e:
                print(f"Error encoding answers for similarity: {e}")
                similarity = 0.0  # Default to 0 similarity on error

            # Check if answers are semantically similar or one contains the other
            contains_check = (expected_answer.lower().strip() in predicted_answer.lower().strip() or predicted_answer.lower().strip() in expected_answer.lower().strip())

            # Consider similarity and containment for validation
            is_valid = (similarity > 0.7 and confidence > 0.4) or (contains_check and confidence > 0.5)
            return {"is_valid": is_valid, "confidence": confidence.item(), "similarity": similarity, "predicted_answer": predicted_answer, "expected_answer": expected_answer}

        except Exception as e:
            # Catch specific errors from pipeline if possible
            print(f"Error during QA validation: {e}")
            return {"is_valid": False, "confidence": 0.0, "similarity": 0.0, "predicted_answer": "", "expected_answer": expected_answer, "error": str(e)}

    def determine_expected_length(self, answer: str) -> str:
        """Determine expected answer length category based on word count."""
        word_count = len(answer.split())

        if word_count <= 5:
            return "brief (few words)"
        elif word_count <= 15:
            return "short (1-2 sentences)"
        elif word_count <= 30:
            return "medium (2-4 sentences)"
        else:
            return "long (paragraph+)"

    def generate_comprehensive_questions(self, context: str, num_questions: int = 8, difficulty: str = "medium") -> List[ShortAnswerQuestion]:
        """Generate comprehensive set of short answer questions, considering difficulty."""
        questions = []
        generated_pairs = set()  # To avoid duplicate question-answer pairs

        # Extract key information
        key_info = self.extract_key_concepts(context)

        # Combine potential answers from entities and concepts
        potential_answers = [e['text'] for e in key_info['entities']] + [c['text'] for c in key_info['concepts']]
        random.shuffle(potential_answers)  # Shuffle to mix entity and concept based questions

        # Determine number of attempts per answer based on difficulty
        attempts_per_answer_map = {"easy": 4, "medium": 8, "hard": 12}  # Increased attempts for all difficulties again
        attempts_per_answer = attempts_per_answer_map.get(difficulty, 8)

        answers_processed = 0
        for answer in potential_answers:
            if len(questions) >= num_questions:
                break
            answers_processed += 1
            if answers_processed > num_questions * 20:  # Further increased limit to try more answers
                print(f"Reached maximum answer processing attempts ({num_questions * 20}). Stopping.")
                break
            for attempt in range(attempts_per_answer):
                if len(questions) >= num_questions:
                    break

                # Choose which model to use (can alternate or use both)
                if attempt % 2 == 0:
                    question = self.generate_question_with_t5(context, answer, difficulty=difficulty)
                else:
                    question = self.generate_question_with_flan(context, answer, difficulty=difficulty)

                # Basic cleaning and validation before full QA check
                question = question.strip()
                if not question or not question.endswith('?') or len(question.split()) < 5:
                    continue

                # Ensure question is unique
                q_a_pair = (question, answer)
                if q_a_pair in generated_pairs:
                    continue

                # Validate question-answer pair
                validation = self.validate_question_answer_pair(question, answer, context)

                # Filtering based on difficulty and validation confidence
                confidence_threshold = 0.3  # Base threshold

                # Adjust confidence threshold based on requested difficulty
                if difficulty == "easy":
                    confidence_threshold = 0.25  # lower threshold for easy questions
                elif difficulty == "hard":
                    confidence_threshold = 0.35  # higher threshold for hard questions

                if validation["is_valid"] and validation["confidence"] > confidence_threshold:  # confidence threshold
                    # type and classified difficulty
                    question_type = self.determine_question_type(question)
                    # Check if nlp is loaded before classifying difficulty
                    if self.nlp:
                        classified_difficulty = self.classify_question_difficulty(question, answer, context)  # Classify generated question's actual difficulty
                    else:
                        classified_difficulty = "medium"  # Default to medium if spaCy not loaded

                    # Add the question if its classified difficulty is the requested one or one level below
                    # This allows some flexibility while aiming for the target difficulty
                    difficulty_levels = ["easy", "medium", "hard"]
                    requested_index = difficulty_levels.index(difficulty)
                    classified_index = difficulty_levels.index(classified_difficulty)

                    # Accept if classified difficulty is at or one level below requested difficulty
                    if classified_index >= requested_index or (requested_index > 0 and classified_index == requested_index - 1):

                        keywords = self.extract_keywords(f"{question} {answer}")
                        expected_length = self.determine_expected_length(answer)

                        saq = ShortAnswerQuestion(question=question, answer=answer, context_sentence=context[:200] + "..." if len(context) > 200 else context, question_type=question_type, difficulty=classified_difficulty, confidence=validation["confidence"], keywords=keywords[:5], expected_length=expected_length)
                        questions.append(saq)
                        generated_pairs.add(q_a_pair)  # Add to history
                        # If we found a question of the requested difficulty, move to the next answer
                        if classified_difficulty == difficulty:
                            break
        # Fallback: Generate questions directly from sentences if not enough generated
        if len(questions) < num_questions:
            print(f"Warning: Could not generate {num_questions} questions of the requested difficulty. Adding fallback questions.")
            for sentence in key_info["sentences"]:
                if len(questions) >= num_questions:
                    break

                # Generate a question based on the sentence (can use template or model)
                # Simple template fallback
                question = f"What is discussed in the sentence: \"{sentence[:50]}...\"?"
                answer = sentence  # The sentence itself is the "answer" in this case

                # Validate (less strict for fallback)
                validation = self.validate_question_answer_pair(question, answer, context)

                # Even if not perfectly valid, added as a fallback if needed and unique
                q_a_pair = (question, answer)
                if q_a_pair not in generated_pairs:
                    difficulty = "easy"  # Fallback questions are usually easy
                    question_type = "factual"
                    keywords = self.extract_keywords(sentence)[:5]
                    expected_length = self.determine_expected_length(answer)

                    saq = ShortAnswerQuestion(question=question, answer="Key points from the sentence.", context_sentence=sentence, question_type=question_type, difficulty=difficulty, confidence=validation["confidence"] if validation["is_valid"] else 0.1, keywords=keywords, expected_length="short (1-2 sentences)")
                    questions.append(saq)
                    generated_pairs.add(q_a_pair)

        # Sort by confidence (or potentially by classified difficulty later) and return
        questions.sort(key=lambda x: x.confidence, reverse=True)
        return questions[:num_questions]

def main():
    """Main function to demonstrate the advanced SAQ generator."""
    generator = AdvancedShortAnswerGenerator()

    print(" Short Answer Question Generator")
    # Get user input
    user_context = input("Enter your context (or press Enter to use sample): ").strip()

    try:
        num_questions = int(input("Number of questions to generate (default 6): ") or "6")
    except ValueError:
        num_questions = 6
    print(f"\nGenerating {num_questions} short answer questions...")

    # Generate questions, passing the difficulty
    questions = generator.generate_comprehensive_questions(user_context, num_questions)

    # Display results
    if questions:
        for i, q in enumerate(questions, 1):
            print(f"\nQuestion {i}: [CLASSIFIED: {q.difficulty.upper()}] ({q.question_type})")  # Display classified difficulty
            print(f"Q: {q.question}")
            print(f"A: {q.answer}")
            print(f"Expected Length: {q.expected_length}")
    else:
        print("No high-quality questions could be generated from the provided context.")
        print("Try providing a longer, more detailed context with specific information.")

    print("\nGeneration complete!")

if __name__ == "__main__":
    main()

Cell output (stdout):
 Short Answer Question Generator
Enter your context (or press Enter to use sample): India has numerous national parks dedicated to preserving wildlife and biodiversity. Some of the most famous include Jim Corbett National Park in Uttarakhand, known for tigers; Kaziranga National Park in Assam, home to the one-horned rhinoceros; and Sundarbans in West Bengal, famous for mangrove forests and Royal Bengal Tigers. These parks also support eco-tourism and help protect endangered species and fragile ecosystems.
Number of questions to generate (default 6): 3

Generating 3 short answer questions...

Question 1: [CLASSIFIED: HARD] (factual)
Q: What national park in Assam is home to the one-horned rhinoceros?
A: Kaziranga National Park
Expected Length: brief (few words)

Question 2: [CLASSIFIED: HARD] (factual)
Q: What national park is home to the one-horned rhinoceros?
A: Kaziranga National Park
Expected Length: brief (few words)

Question 3: [CLASSIFIED: MEDIUM] (factual)
Q: Which type of ecosystems are protected by national parks?
A: fragile ecosystems
Expected Length: brief (few words)
-
"\n",
|
588 |
-
"Generation complete!\n"
|
589 |
-
]
|
590 |
-
}
|
591 |
-
]
|
592 |
-
},
|
593 |
-
{
|
594 |
-
"cell_type": "code",
|
595 |
-
"source": [],
|
596 |
-
"metadata": {
|
597 |
-
"id": "EFvGMWctDueq"
|
598 |
-
},
|
599 |
-
"execution_count": null,
|
600 |
-
"outputs": []
|
601 |
-
}
|
602 |
-
]
|
603 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
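Note: the acceptance rule in the deleted cell above can be read in isolation. The sketch below restates it as a standalone predicate; the function name and module-level constant are illustrative only and were not part of the original class.

# Sketch of the difficulty filter used above: keep a generated question when its
# classified difficulty is at or above the requested level, or exactly one level below it.
DIFFICULTY_LEVELS = ["easy", "medium", "hard"]

def difficulty_acceptable(requested: str, classified: str) -> bool:
    requested_index = DIFFICULTY_LEVELS.index(requested)
    classified_index = DIFFICULTY_LEVELS.index(classified)
    return classified_index >= requested_index or (
        requested_index > 0 and classified_index == requested_index - 1
    )

# Example: difficulty_acceptable("hard", "medium") -> True; difficulty_acceptable("hard", "easy") -> False
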
true_false_generator.ipynb
DELETED
@@ -1,250 +0,0 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"authorship_tag": "ABX9TyPvN5XJaoOJQM9ow757H9XQ",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/DishaKushwah/custom-quiz-generator/blob/main/true_false_generator.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nq6K-9RY_KQa",
"outputId": "4ad2a243-b1b1-482d-88ae-da17878ae13c"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Device set to use cpu\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Enter the context for true/false generation: Currency is a system of money used for buying and selling goods and services. It has evolved from the barter system to coins, paper money, and now digital payments. International trade allows countries to exchange resources, goods, and services, often using foreign currencies like the US Dollar or Euro. Exchange rates determine how much one currency is worth compared to another. Organizations like the World Trade Organization (WTO) regulate global trade practices.\n",
"Number of questions to generate (default 5): 5\n",
"Generating 5 questions from context...\n",
"\n",
"--- Generated Questions ---\n",
"TRUE/FALSE QUESTIONS\n",
"\n",
"1. Exchange rates do not determine how much one currency is worth compared to another.\n",
"   Answer: FALSE\n",
"   Explanation: According to the context, the correct statement is: Exchange rates determine how much one currency is worth compared to another.\n",
"\n",
"2. It has not evolved from the barter system to coins, paper money, and now digital payments.\n",
"   Answer: FALSE\n",
"   Explanation: According to the context, the correct statement is: It has evolved from the barter system to coins, paper money, and now digital payments.\n",
"\n",
"3. It has evolved from the barter system to coins, paper money, and now digital payments.\n",
"   Answer: TRUE\n",
"   Explanation: Based on the context: It has evolved from the barter system to coins, paper money, and now digital payments.\n",
"\n",
"4. A currency is a system of money used for buying and selling goods and services.\n",
"   Answer: FALSE\n",
"   Explanation: According to the context, the correct statement is: Currency is a system of money used for buying and selling goods and services.\n",
"\n",
"5. True or False: Currency is a system of money used for buying and selling goods and services.\n",
"   Answer: TRUE\n",
"   Explanation: Based on the context: Currency is a system of money used for buying and selling goods and services.\n",
"Questions saved to questions.json\n"
]
}
],
"source": [
"#true false\n",
"import torch\n",
"from transformers import (AutoTokenizer, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration, pipeline)\n",
"import nltk\n",
"import random\n",
"import re\n",
"import json\n",
"from typing import List, Dict, Tuple\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"class TrueFalseQuestionGenerator:\n",
"    def __init__(self, model_name: str = \"google/flan-t5-large\"):\n",
"        self.model_name = model_name\n",
"        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"        # Load tokenizer and model\n",
"        if \"t5\" in model_name.lower():\n",
"            self.tokenizer = T5Tokenizer.from_pretrained(model_name)\n",
"            self.model = T5ForConditionalGeneration.from_pretrained(model_name)\n",
"        else:\n",
"            self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"            self.model = AutoModelForCausalLM.from_pretrained(model_name)\n",
"        self.model.to(self.device)\n",
"        self.model.eval()\n",
"\n",
"        # Initialize question generation pipeline\n",
"        self.question_pipeline = pipeline(\n",
"            \"text2text-generation\",\n",
"            model=self.model,\n",
"            tokenizer=self.tokenizer,\n",
"            device=0 if torch.cuda.is_available() else -1)\n",
"\n",
"    def extract_key_facts(self, context: str) -> List[str]:\n",
"        # Split the context into sentences using regex based on common punctuation\n",
"        sentences = re.split(r'(?<=[.!?])\\s+', context)\n",
"        # Filter sentences that are likely to contain factual information\n",
"        factual_sentences = []\n",
"        for sentence in sentences:\n",
"            # Look for sentences with specific patterns that indicate facts\n",
"            if (len(sentence.split()) > 5 and\n",
"                any(keyword in sentence.lower() for keyword in\n",
"                    ['is', 'are', 'was', 'were', 'has', 'have', 'can', 'will','does', 'did', 'contains', 'includes', 'located', 'founded','established', 'invented', 'discovered', 'developed', 'used', 'use'])):\n",
"                factual_sentences.append(sentence.strip())\n",
"        return factual_sentences[:10] # Limit to 10 key facts\n",
"\n",
"    def generate_true_questions(self, context: str, num_questions: int = 5) -> List[Dict]:\n",
"        facts = self.extract_key_facts(context)\n",
"        true_questions = []\n",
"        # Ensure we don't try to generate more questions than available facts\n",
"        num_to_generate = min(num_questions, len(facts))\n",
"        for i, fact in enumerate(facts[:num_to_generate]):\n",
"            # Create prompts for generating true/false questions\n",
"            prompts = [f\"Create a statement that can be answered true or false based on: {fact}\"]\n",
"            prompt = random.choice(prompts)\n",
"\n",
"            try:\n",
"                # Generate question using the model\n",
"                response = self.question_pipeline(prompt, max_new_tokens=100, num_return_sequences=1, temperature=0.7, do_sample=True, pad_token_id=self.tokenizer.eos_token_id)\n",
"                generated_text = response[0]['generated_text'].strip()\n",
"                question = self.clean_question(generated_text)\n",
"                if question:\n",
"                    true_questions.append({'question': question,'answer': True,'explanation': f\"Based on the context: {fact}\",'source_fact': fact })\n",
"            except Exception as e:\n",
"                print(f\"Error generating question {i+1}: {str(e)}\")\n",
"                continue\n",
"        return true_questions\n",
"\n",
"    def generate_false_questions(self, context: str, num_questions: int = 5) -> List[Dict]:\n",
"        \"\"\"\n",
"        Generate FALSE questions by modifying facts from the context\n",
"        \"\"\"\n",
"        facts = self.extract_key_facts(context)\n",
"        false_questions = []\n",
"        num_to_generate = min(num_questions, len(facts))\n",
"\n",
"        for i, fact in enumerate(facts[:num_to_generate]):\n",
"            # Create prompts for generating false statements that are plausible but contradict the fact\n",
"            prompts = [f\"Create a plausible false statement based on this fact: {fact}\",f\"Make this statement false by changing a key detail: {fact}\"]\n",
"            prompt = random.choice(prompts)\n",
"\n",
"            try:\n",
"                response = self.question_pipeline(prompt,max_new_tokens=100,num_return_sequences=1,temperature=0.8, do_sample=True,pad_token_id=self.tokenizer.eos_token_id)\n",
"                generated_text = response[0]['generated_text'].strip()\n",
"                question = self.clean_question(generated_text)\n",
"                if question:\n",
"                    false_questions.append({'question': question,'answer': False,'explanation': f\"According to the context, the correct statement is: {fact}\",'source_fact': fact})\n",
"            except Exception as e:\n",
"                print(f\"Error generating false question {i+1}: {str(e)}\")\n",
"                continue\n",
"        return false_questions\n",
"\n",
"    def clean_question(self, text: str) -> str:\n",
"        # Remove common prefixes\n",
"        text = re.sub(r'^(Question:|Q:|True/False:|Statement:)\\s*', '', text, flags=re.IGNORECASE)\n",
"\n",
"        # Ensure the question ends with proper punctuation\n",
"        text = text.strip()\n",
"        if not text.endswith(('.', '!', '?')):\n",
"            text += '.'\n",
"\n",
"        # Capitalize first letter\n",
"        if text:\n",
"            text = text[0].upper() + text[1:]\n",
"\n",
"        # Remove standalone \"True\" or \"False\" if it appears at the beginning\n",
"        text = re.sub(r'^(True|False)\\.\\s*', '', text, flags=re.IGNORECASE)\n",
"        text = re.sub(r'^(True|False)\\?\\s*', '', text, flags=re.IGNORECASE)\n",
"        return text if len(text) > 10 else None\n",
"\n",
"    def generate_questions(self, context: str, num_questions: int = 10) -> List[Dict]:\n",
"        if not context.strip():\n",
"            raise ValueError(\"Context cannot be empty\")\n",
"        print(f\"Generating {num_questions} questions from context...\")\n",
"\n",
"        # Split questions evenly between true and false\n",
"        num_true = num_questions // 2\n",
"        num_false = num_questions - num_true\n",
"\n",
"        true_questions = self.generate_true_questions(context, num_true)\n",
"        false_questions = self.generate_false_questions(context, num_false)\n",
"\n",
"        # Combine and shuffle questions\n",
"        all_questions = true_questions + false_questions\n",
"        random.shuffle(all_questions)\n",
"\n",
"        # Add question numbers\n",
"        for i, question in enumerate(all_questions, 1):\n",
"            question['id'] = i\n",
"        return all_questions\n",
"\n",
"    def display_questions(self, questions: List[Dict], show_answers: bool = False):\n",
"        print(\"TRUE/FALSE QUESTIONS\")\n",
"        for q in questions:\n",
"            print(f\"\\n{q['id']}. {q['question']}\")\n",
"            if show_answers:\n",
"                answer_text = \"TRUE\" if q['answer'] else \"FALSE\"\n",
"                print(f\"   Answer: {answer_text}\")\n",
"                print(f\"   Explanation: {q['explanation']}\")\n",
"\n",
"    def save_questions(self, questions: List[Dict], filename: str = \"questions.json\"):\n",
"        with open(filename.strip(), 'w', encoding='utf-8') as f:\n",
"            json.dump(questions, f, indent=2, ensure_ascii=False)\n",
"        print(f\"Questions saved to {filename.strip()}\")\n",
"\n",
"# Example usage and testing\n",
"def main():\n",
"    # Initialize the generator\n",
"    generator = TrueFalseQuestionGenerator()\n",
"    # Sample context for testing\n",
"    sample_context =input(\"Enter the context for true/false generation: \")\n",
"    try:\n",
"        num_questions = int(input(\"Number of questions to generate (default 5): \") or \"5\")\n",
"    except ValueError:\n",
"        num_questions = 5\n",
"    # Generate questions\n",
"    questions = generator.generate_questions(sample_context, num_questions)\n",
"    # Display questions\n",
"    generator.display_questions(questions, show_answers=True)\n",
"    # Save questions\n",
"    generator.save_questions(questions, \"questions.json\")\n",
"\n",
"if __name__ == \"__main__\":\n",
"    main()"
]
}
]
}

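Note: the deleted notebook above builds every true/false statement with a single text2text call per extracted fact. The lines below condense that step into a minimal, standalone sketch; it reuses the same pipeline task, model, and decoding settings as the cell, and the fact string is just an example taken from the recorded run.

# Minimal sketch of the per-fact generation step from the deleted notebook.
from transformers import pipeline

qg = pipeline("text2text-generation", model="google/flan-t5-large")
fact = "Exchange rates determine how much one currency is worth compared to another."
prompt = f"Create a statement that can be answered true or false based on: {fact}"
statement = qg(prompt, max_new_tokens=100, do_sample=True, temperature=0.7)[0]["generated_text"]
print(statement)
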
true_false_generator.py
DELETED
@@ -1 +0,0 @@
#true and false logic

truefalse_quiz.py
ADDED
@@ -0,0 +1,98 @@
import random
import nltk
from transformers import pipeline
from nltk.tokenize import sent_tokenize

# Download required tokenizer
nltk.download('punkt_tab', quiet=True)

# Load NLI model
nli = pipeline("text-classification", model="facebook/bart-large-mnli")

# Input validation
def validate_inputs(context, num_questions, difficulty):
    if not context.strip():
        raise ValueError("Context cannot be empty.")
    sentences = sent_tokenize(context)
    if len(sentences) < num_questions:
        raise ValueError(f"Context has only {len(sentences)} sentences, but {num_questions} questions requested.")
    if difficulty not in ["easy", "medium", "hard"]:
        raise ValueError("Difficulty must be 'easy', 'medium', or 'hard'.")
    return sentences

# Difficulty-based sentence modifier
def apply_noise(sentence: str, level: str) -> str:
    if level == "easy":
        return sentence
    elif level == "medium":
        if "Sun" in sentence:
            return sentence.replace("Sun", "Moon")
        return sentence.replace("is", "is not") if "is" in sentence else sentence
    elif level == "hard":
        if "eight" in sentence:
            return sentence.replace("eight", "ten")
        return sentence.replace("planets", "stars") if "planets" in sentence else sentence
    return sentence

# Statement generator
def generate_statements(context, n, difficulty, sentences):
    random.seed(42)
    selected = random.sample(sentences, min(n * 2, len(sentences)))
    final = []
    for s in selected:
        clean = s.strip()
        modified = apply_noise(clean, difficulty)
        label = "ENTAILMENT" if clean == modified else "CONTRADICTION"
        final.append((modified, label))
        if len(final) >= n:
            break
    return final

# Get valid user answer
def get_user_answer():
    while True:
        user = input("True or False? ").strip().lower()
        if user in ["true", "false"]:
            return user
        print("Please enter 'true' or 'false'.")

# Main logic
try:
    context = input(">> Enter context text: ")
    num_questions = int(input("\n>> How many questions do you want to generate? "))
    difficulty = input("\n>> Enter difficulty level (easy/medium/hard): ").strip().lower()

    sentences = validate_inputs(context, num_questions, difficulty)
    questions = generate_statements(context, num_questions, difficulty, sentences)

    if len(questions) < num_questions:
        print(f"Warning: Only {len(questions)} questions generated due to limited context.")

    print("\n--- QUIZ STARTS ---\n")
    score = 0

    for idx, (statement, actual_label) in enumerate(questions, 1):
        print(f"Q{idx}: {statement}")
        user = get_user_answer()

        # Format input for facebook/bart-large-mnli
        input_text = f"{context} [SEP] {statement}"
        result = nli(input_text)[0]
        if result["label"] == "neutral":
            print("Skipping ambiguous statement.\n")
            continue
        model_label = "ENTAILMENT" if result["label"] == "entailment" else "CONTRADICTION"

        if model_label == "ENTAILMENT" and user == "true":
            print("Correct!\n")
            score += 1
        elif model_label == "CONTRADICTION" and user == "false":
            print("Correct!\n")
            score += 1
        else:
            print(f"Incorrect! (Correct answer: {'True' if model_label == 'ENTAILMENT' else 'False'})\n")

    print(f"\n--- Final Score: {score}/{len(questions)} ---")

except ValueError as e:
    print(f"Error: {e}")
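Note: the grading step in the added truefalse_quiz.py can be pulled out into a reusable helper. The sketch below mirrors the script's "premise [SEP] hypothesis" formatting and assumes, as the script itself does, that facebook/bart-large-mnli reports lowercase labels ("entailment" / "neutral" / "contradiction"); the helper name is illustrative and not part of the commit.

# Sketch: reusable wrapper around the NLI check performed in the quiz loop above.
from typing import Optional
from transformers import pipeline

nli = pipeline("text-classification", model="facebook/bart-large-mnli")

def verify_statement(context: str, statement: str) -> Optional[str]:
    # Returns "TRUE", "FALSE", or None when the model is unsure (neutral).
    result = nli(f"{context} [SEP] {statement}")[0]
    if result["label"] == "neutral":
        return None
    return "TRUE" if result["label"] == "entailment" else "FALSE"

# Example: verify_statement("The Sun is a star.", "The Sun is a planet.") is expected to return "FALSE".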