Update app.py
app.py CHANGED
```diff
@@ -3,10 +3,10 @@ from pydantic import BaseModel
 from typing import Optional, List
 from datetime import datetime
 import torch
+from transformers import T5ForConditionalGeneration, T5Tokenizer
 import time
 import traceback
 import logging
-from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -19,12 +19,12 @@ API_KEYS = {
     "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1"  # In production, use a secure database
 }
 
-# Initialize model and tokenizer - using a dedicated paraphrasing model
-MODEL_NAME = "
+# Initialize model and tokenizer - using a dedicated T5 paraphrasing model
+MODEL_NAME = "Vamsi/T5_Paraphrase_Paws"  # Specifically fine-tuned for paraphrasing
 try:
     print("Loading model and tokenizer...")
-    tokenizer =
-    model =
+    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
+    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = model.to(device)
     print(f"Model and tokenizer loaded successfully on {device}!")
```
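Note: the loading change swaps the earlier Pegasus checkpoint for Vamsi/T5_Paraphrase_Paws and caches weights under model_cache. A minimal standalone sketch of the new loading path, for trying it outside the app (the .eval() call is an assumption, not part of the commit):

```python
# Standalone check of the new loading path; MODEL_NAME and cache_dir are
# taken verbatim from the diff above.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_NAME = "Vamsi/T5_Paraphrase_Paws"
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()  # .eval() is an assumption; the commit only calls .to(device)
print(f"Loaded on {device}")
```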
```diff
@@ -52,41 +52,47 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
     try:
         # Get parameters based on style
         params = {
-            "standard": {"temperature": 1.0, "
-            "formal": {"temperature": 0.7, "
-            "casual": {"temperature": 1.
-            "creative": {"temperature": 1.5, "
-        }.get(style, {"temperature": 1.0, "
+            "standard": {"temperature": 1.0, "top_p": 0.9, "top_k": 50},
+            "formal": {"temperature": 0.7, "top_p": 0.85, "top_k": 40},
+            "casual": {"temperature": 1.2, "top_p": 0.95, "top_k": 60},
+            "creative": {"temperature": 1.5, "top_p": 0.98, "top_k": 80},
+        }.get(style, {"temperature": 1.0, "top_p": 0.9, "top_k": 50})
+
+        # T5 models require a specific text format for tasks
+        text_to_paraphrase = f"paraphrase: {text} </s>"
 
         # Tokenize the input text
-
-
-
-
-
-
-
-
+        encoding = tokenizer.encode_plus(
+            text_to_paraphrase,
+            padding="longest",
+            max_length=256,
+            truncation=True,
+            return_tensors="pt"
+        )
+        input_ids = encoding["input_ids"].to(device)
+        attention_mask = encoding["attention_mask"].to(device)
+
         # Generate paraphrases
         with torch.no_grad():
             outputs = model.generate(
-                input_ids,
-
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_length=256,
                 num_return_sequences=num_variations,
-                num_beams=num_variations
+                num_beams=num_variations * 2,
                 temperature=params["temperature"],
+                top_p=params["top_p"],
                 top_k=params["top_k"],
-
-
-                do_sample=do_sample  # <-- Fix applied here
+                do_sample=True,
+                early_stopping=True,
             )
-
+
         # Decode the generated outputs
         paraphrases = [
-            tokenizer.decode(output, skip_special_tokens=True)
+            tokenizer.decode(output, skip_special_tokens=True)
             for output in outputs
         ]
-
+
         return paraphrases
 
     except Exception as e:
```
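The substance of the commit is in this last hunk: temperature, top_p, and top_k are ignored by generate() unless do_sample=True is passed (the removed `do_sample=do_sample  # <-- Fix applied here` line suggests this was the bug being chased), and this T5 checkpoint expects its input wrapped in a `paraphrase: ... </s>` task prefix. A self-contained sketch of the new generation path, reusing `tokenizer`, `model`, and `device` from the loading sketch above; parameter values mirror the diff's "standard" style, and the example sentence is invented:

```python
# Sketch of the fixed generation call, under the assumptions stated above.
text = "The weather was terrible, so the match was postponed."

# This checkpoint is trained with a task prefix and an explicit </s> terminator.
encoding = tokenizer.encode_plus(
    f"paraphrase: {text} </s>",
    padding="longest",
    max_length=256,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model.generate(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        max_length=256,
        num_return_sequences=3,
        num_beams=6,        # the diff uses num_variations * 2
        temperature=1.0,
        top_p=0.9,
        top_k=50,
        do_sample=True,     # without this, temperature/top_p/top_k are silently ignored
        early_stopping=True,
    )

paraphrases = [tokenizer.decode(o, skip_special_tokens=True) for o in outputs]
print(paraphrases)
```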
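For completeness, a hypothetical call into the patched helper; the signature is taken from the hunk header (`def generate_paraphrase(text: str, style: str = "standard", num_variations: int ...`), whose trailing parameters are truncated on the page itself:

```python
# Hypothetical usage of the patched function; unknown styles fall back
# to the "standard" parameter set via the dict .get() in the diff.
variants = generate_paraphrase(
    "The committee rejected the proposal on Tuesday.",
    style="formal",
    num_variations=3,
)
for v in variants:
    print("-", v)
```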