Avinash250325 commited on
Commit
a359070
·
verified ·
1 Parent(s): ed40669

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +16 -7
backend.py CHANGED
@@ -5,7 +5,7 @@ model_name = "./T5base_Question_Generation"
5
  tokenizer = T5Tokenizer.from_pretrained(model_name)
6
  model = T5ForConditionalGeneration.from_pretrained(model_name)
7
 
8
- def get_question(tag, difficulty, context, answer="", num_questions=1, max_length=150):
9
  """
10
  Generate questions using the fine-tuned T5 model
11
 
@@ -29,12 +29,21 @@ def get_question(tag, difficulty, context, answer="", num_questions=1, max_lengt
29
 
30
  # Decide generation strategy
31
  if num_questions == 1:
32
- output = model.generate(
33
- input_ids=features['input_ids'],
34
- attention_mask=features['attention_mask'],
35
- max_length=max_length,
36
- do_sample=False
37
- )
 
 
 
 
 
 
 
 
 
38
  else:
39
  output = model.generate(
40
  input_ids=features['input_ids'],
 
5
  tokenizer = T5Tokenizer.from_pretrained(model_name)
6
  model = T5ForConditionalGeneration.from_pretrained(model_name)
7
 
8
+ def get_question(tag, difficulty, context, answer="", num_questions=1, use_beam_search=False, num_beams=3, max_length=150):
9
  """
10
  Generate questions using the fine-tuned T5 model
11
 
 
29
 
30
  # Decide generation strategy
31
  if num_questions == 1:
32
+ if use_beam_search:
33
+ output = model.generate(
34
+ input_ids=features['input_ids'],
35
+ attention_mask=features['attention_mask'],
36
+ max_length=max_length,
37
+ num_beams=num_beams,
38
+ early_stopping=False
39
+ )
40
+ else:
41
+ output = model.generate(
42
+ input_ids=features['input_ids'],
43
+ attention_mask=features['attention_mask'],
44
+ max_length=max_length,
45
+ do_sample=False
46
+ )
47
  else:
48
  output = model.generate(
49
  input_ids=features['input_ids'],