Update app.py
app.py CHANGED
@@ -2,9 +2,8 @@ from fastapi import FastAPI, HTTPException, Header, Depends
 from pydantic import BaseModel
 from typing import Optional, List
 from datetime import datetime
-from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 import torch
-from
+from transformers import PegasusForConditionalGeneration, PegasusTokenizer
 import time
 
 app = FastAPI()
@@ -14,11 +13,14 @@ API_KEYS = {
     "your-secret-api-key": "user1" # In production, use a secure database
 }
 
-#
+# Initialize model and tokenizer with smaller model for Spaces
 MODEL_NAME = "tuner007/pegasus_paraphrase"
-
-
-
+print("Loading model and tokenizer...")
+tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
+model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
+device = "cpu" # Force CPU for Spaces deployment
+model = model.to(device)
+print("Model and tokenizer loaded successfully!")
 
 class TextRequest(BaseModel):
     text: str
@@ -30,16 +32,6 @@ class BatchRequest(BaseModel):
     style: Optional[str] = "standard"
     num_variations: Optional[int] = 1
 
-def get_paraphrase_params(style: str):
-    """Get model parameters based on style"""
-    params = {
-        "standard": {"temperature": 1.0, "top_k": 50, "top_p": 0.95},
-        "formal": {"temperature": 0.7, "top_k": 30, "top_p": 0.9},
-        "casual": {"temperature": 1.3, "top_k": 100, "top_p": 0.95},
-        "creative": {"temperature": 1.5, "top_k": 120, "top_p": 0.99},
-    }
-    return params.get(style, params["standard"])
-
 async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
     if api_key not in API_KEYS:
         raise HTTPException(status_code=403, detail="Invalid API key")
@@ -48,7 +40,12 @@ async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
 def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
     try:
         # Get parameters based on style
-        params =
+        params = {
+            "standard": {"temperature": 1.0, "top_k": 50},
+            "formal": {"temperature": 0.7, "top_k": 30},
+            "casual": {"temperature": 1.3, "top_k": 100},
+            "creative": {"temperature": 1.5, "top_k": 120},
+        }.get(style, {"temperature": 1.0, "top_k": 50})
 
         # Tokenize the input text
         inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
@@ -62,7 +59,6 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
             num_beams=num_variations * 2,
             temperature=params["temperature"],
             top_k=params["top_k"],
-            top_p=params["top_p"],
             do_sample=True,
             early_stopping=True,
         )
@@ -78,6 +74,10 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
 
+@app.get("/")
+async def root():
+    return {"message": "Paraphrase API is running. Use /docs for API documentation."}
+
 @app.post("/api/paraphrase")
 async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
     try:
@@ -133,17 +133,4 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
         }
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.get("/api/health")
-async def health_check(api_key: str = Depends(verify_api_key)):
-    return {
-        "status": "healthy",
-        "model": MODEL_NAME,
-        "device": device,
-        "timestamp": datetime.now().isoformat()
-    }
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+        raise HTTPException(status_code=500, detail=str(e))
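
For quick post-deploy verification, below is a minimal requests-based client sketch; it is not part of app.py or this commit. The / and /api/paraphrase routes, the text field, and the X-API-Key header name come from the diff above; the base URL is a placeholder for your Space, the style and num_variations request fields are assumed to be accepted by TextRequest (only text is visible in the diff), and the response handling is illustrative.

# Hypothetical client sketch -- not part of app.py. Replace BASE_URL with your own Space URL.
import requests

BASE_URL = "https://your-space.hf.space"   # placeholder deployment URL (assumption)
API_KEY = "your-secret-api-key"            # must match an entry in API_KEYS on the server

def call_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> dict:
    """POST to /api/paraphrase with the API key header checked by verify_api_key."""
    # NOTE: fastapi.Header normally takes alias= rather than name= for a custom header,
    # so depending on the FastAPI version the server may expect "api-key" instead of
    # "X-API-Key"; sending both keeps the sketch working either way.
    response = requests.post(
        f"{BASE_URL}/api/paraphrase",
        headers={"X-API-Key": API_KEY, "api-key": API_KEY},
        json={"text": text, "style": style, "num_variations": num_variations},
        timeout=120,  # first request may be slow while the model warms up
    )
    response.raise_for_status()  # 403 = bad key, 500 = generation error
    return response.json()

if __name__ == "__main__":
    print(requests.get(f"{BASE_URL}/", timeout=30).json())  # root route added in this commit
    print(call_paraphrase("The quick brown fox jumps over the lazy dog.", style="casual"))

The generous timeout allows for the slower CPU-only generation that this change forces with device = "cpu".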