Update app.py
app.py CHANGED
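This commit replaces the old T5-based setup with tuner007/pegasus_paraphrase, forces CPU for the Spaces deployment, simplifies error handling, and drops the logging/traceback scaffolding along with the /api/test debug endpoint.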
@@ -3,14 +3,8 @@ from pydantic import BaseModel
 from typing import Optional, List
 from datetime import datetime
 import torch
-from transformers import …
+from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 import time
-import traceback
-import logging
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
 
 app = FastAPI()
 
@@ -19,19 +13,14 @@ API_KEYS = {
     "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1"  # In production, use a secure database
 }
 
-# Initialize model and tokenizer
-MODEL_NAME = "…
-…
-…
-…
-…
-…
-…
-    print(f"Model and tokenizer loaded successfully on {device}!")
-except Exception as e:
-    error_msg = f"Error loading model: {str(e)}\n{traceback.format_exc()}"
-    print(error_msg)
-    logger.error(error_msg)
+# Initialize model and tokenizer with smaller model for Spaces
+MODEL_NAME = "tuner007/pegasus_paraphrase"
+print("Loading model and tokenizer...")
+tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
+model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
+device = "cpu"  # Force CPU for Spaces deployment
+model = model.to(device)
+print("Model and tokenizer loaded successfully!")
 
 class TextRequest(BaseModel):
     text: str
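Note that the new startup block loads Pegasus eagerly at module import and drops the old try/except, so a failed download now crashes the Space at startup instead of being logged. A quick local sanity check of the load path (a sketch reusing only names that appear in the diff):

import time
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

t0 = time.time()
tokenizer = PegasusTokenizer.from_pretrained("tuner007/pegasus_paraphrase", cache_dir="model_cache")
model = PegasusForConditionalGeneration.from_pretrained("tuner007/pegasus_paraphrase", cache_dir="model_cache").to("cpu")
print(f"Model loaded in {time.time() - t0:.1f}s")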
@@ -52,36 +41,23 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
     try:
         # Get parameters based on style
         params = {
-            "standard": {"temperature": 1.…
-            "formal": {"temperature": …
-            "casual": {"temperature": 1.…
-            "creative": {"temperature": …
-        }.get(style, {"temperature": 1.0, "…
-
-        # T5 models require a specific text format for tasks
-        text_to_paraphrase = f"paraphrase: {text} </s>"
+            "standard": {"temperature": 1.5, "top_k": 80},
+            "formal": {"temperature": 1.0, "top_k": 50},
+            "casual": {"temperature": 1.6, "top_k": 100},
+            "creative": {"temperature": 2.8, "top_k": 170},
+        }.get(style, {"temperature": 1.0, "top_k": 50})
 
         # Tokenize the input text
-        encoding = tokenizer(
-            text_to_paraphrase,
-            padding="longest",
-            max_length=256,
-            truncation=True,
-            return_tensors="pt"
-        )
-        input_ids = encoding["input_ids"].to(device)
-        attention_mask = encoding["attention_mask"].to(device)
+        inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
 
         # Generate paraphrases
         with torch.no_grad():
             outputs = model.generate(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=256,
+                **inputs,
+                max_length=200,
                 num_return_sequences=num_variations,
                 num_beams=num_variations * 2,
                 temperature=params["temperature"],
-                top_p=params["top_p"],
                 top_k=params["top_k"],
                 do_sample=True,
                 early_stopping=True,
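Pieced together outside FastAPI, the new tokenize-and-generate path looks like this (a minimal sketch using only the calls visible in the diff; the sample sentence and the choice of three variations are ours):

import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

tokenizer = PegasusTokenizer.from_pretrained("tuner007/pegasus_paraphrase")
model = PegasusForConditionalGeneration.from_pretrained("tuner007/pegasus_paraphrase").to("cpu")

text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to("cpu")
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=200,
        num_return_sequences=3,  # num_variations
        num_beams=6,             # num_variations * 2
        temperature=1.5,         # "standard" style params
        top_k=80,
        do_sample=True,
        early_stopping=True,
    )
print([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])

Because do_sample=True is combined with num_beams > 1, transformers runs beam-search multinomial sampling here, with temperature and top_k reshaping the per-step distribution for each style.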
@@ -89,16 +65,14 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
 
         # Decode the generated outputs
         paraphrases = [
-            tokenizer.decode(output, skip_special_tokens=True)
+            tokenizer.decode(output, skip_special_tokens=True)
             for output in outputs
         ]
 
         return paraphrases
 
     except Exception as e:
-        …
-        logger.error(error_msg)
-        raise HTTPException(status_code=500, detail=error_msg)
+        raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
 
 @app.get("/")
 async def root():
@@ -127,9 +101,7 @@ async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key
         }
 
     except Exception as e:
-        …
-        logger.error(f"{error_msg}\n{traceback.format_exc()}")
-        raise HTTPException(status_code=500, detail=error_msg)
+        raise HTTPException(status_code=500, detail=str(e))
 
 @app.post("/api/batch-paraphrase")
 async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
@@ -161,26 +133,4 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
         }
 
     except Exception as e:
-        …
-        logger.error(f"{error_msg}\n{traceback.format_exc()}")
-        raise HTTPException(status_code=500, detail=error_msg)
-
-# For testing/debugging the API
-@app.get("/api/test")
-async def test_endpoint():
-    try:
-        test_text = "The quick brown fox jumps over the lazy dog."
-        result = generate_paraphrase(test_text, "standard", 1)
-        return {
-            "status": "success",
-            "test_text": test_text,
-            "paraphrased": result,
-            "model": MODEL_NAME,
-            "device": device
-        }
-    except Exception as e:
-        return {
-            "status": "error",
-            "error": str(e),
-            "traceback": traceback.format_exc()
-        }
+        raise HTTPException(status_code=500, detail=str(e))
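The /api/test debug endpoint is removed wholesale, but an equivalent smoke test can still be run in-process (a sketch, assuming the Space's file is importable as app):

from app import generate_paraphrase

print(generate_paraphrase("The quick brown fox jumps over the lazy dog.", "standard", 1))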
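With the Space running, the surviving endpoints can be exercised over HTTP. A sketch using the requests library; the request-body field name and the API-key header name are assumptions, since the BatchRequest model and verify_api_key are not shown in this diff:

import requests

BASE = "http://localhost:7860"  # typical local port for a Space; adjust as needed
HEADERS = {"X-API-Key": "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW"}  # header name assumed

resp = requests.post(
    f"{BASE}/api/batch-paraphrase",
    json={"texts": ["The quick brown fox jumps over the lazy dog."]},  # field name assumed
    headers=HEADERS,
)
print(resp.status_code, resp.json())

On failure, the simplified handlers now surface FastAPI's standard error body, {"detail": "<exception text>"}, with status 500.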