lahiruchamika27 committed
Commit 8355ff9 · verified · 1 Parent(s): 16297d9

Update app.py

Files changed (1):
  app.py (+11, -5)
app.py CHANGED

@@ -60,7 +60,13 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
 
         # Tokenize the input text
         input_ids = tokenizer.encode(text, return_tensors="pt").to(device)
-
+
+        # Determine beam groups
+        beam_groups = min(num_variations, 4) if num_variations > 1 else 1
+
+        # If using diverse beam search, disable sampling
+        do_sample = False if beam_groups > 1 else True
+
         # Generate paraphrases
         with torch.no_grad():
             outputs = model.generate(
@@ -71,16 +77,16 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
                 temperature=params["temperature"],
                 top_k=params["top_k"],
                 diversity_penalty=params["diversity_penalty"],
-                num_beam_groups=min(num_variations, 4) if num_variations > 1 else 1,
-                do_sample=True
+                num_beam_groups=beam_groups,
+                do_sample=do_sample  # <-- Fix applied here
             )
-
+
         # Decode the generated outputs
         paraphrases = [
             tokenizer.decode(output, skip_special_tokens=True)
             for output in outputs
         ]
-
+
         return paraphrases
 
     except Exception as e:
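
For context on the fix: in Hugging Face transformers, setting `num_beam_groups > 1` switches `model.generate` into diverse (group) beam search, which is deterministic, so combining it with `do_sample=True` raises a ValueError, and `diversity_penalty` only takes effect in that mode. Below is a minimal, self-contained sketch of the corrected call path. It mirrors the commit's beam-group logic but is not the repo's actual `app.py`: the checkpoint name, the `num_beams` round-up, and the concrete `temperature`/`top_k`/`diversity_penalty` values are illustrative assumptions.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Illustrative checkpoint: the commit diff does not show which model app.py loads.
MODEL_NAME = "humarin/chatgpt_paraphraser_on_T5_base"

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)


def generate_paraphrases(text: str, num_variations: int = 3) -> list[str]:
    # The "paraphrase:" prefix is specific to this illustrative checkpoint.
    input_ids = tokenizer.encode(f"paraphrase: {text}", return_tensors="pt").to(device)

    # Same capping logic as the commit: at most 4 beam groups.
    beam_groups = min(num_variations, 4) if num_variations > 1 else 1

    # num_beams must be a multiple of num_beam_groups and at least
    # num_return_sequences, so round num_variations up to the next multiple.
    num_beams = -(-num_variations // beam_groups) * beam_groups

    gen_kwargs = dict(
        num_return_sequences=num_variations,
        num_beams=num_beams,
        max_new_tokens=64,
    )
    if beam_groups > 1:
        # Diverse beam search is deterministic: sampling stays off
        # (do_sample defaults to False), and diversity_penalty only
        # takes effect in this mode.
        gen_kwargs.update(num_beam_groups=beam_groups, diversity_penalty=0.5)
    else:
        # A single variation can use plain sampling instead.
        gen_kwargs.update(do_sample=True, temperature=0.9, top_k=50)

    with torch.no_grad():
        outputs = model.generate(input_ids, **gen_kwargs)

    return [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]


print(generate_paraphrases("The quick brown fox jumps over the lazy dog."))
```

One constraint the sketch makes explicit, which the commit's cap of 4 groups would otherwise leave implicit: `num_beams` has to be divisible by `num_beam_groups` and at least `num_return_sequences`, hence the round-up.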