lahiruchamika27 commited on
Commit
937e92f
·
verified ·
1 Parent(s): fa6405a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -32
app.py CHANGED
@@ -2,9 +2,8 @@ from fastapi import FastAPI, HTTPException, Header, Depends
2
  from pydantic import BaseModel
3
  from typing import Optional, List
4
  from datetime import datetime
5
- from transformers import PegasusForConditionalGeneration, PegasusTokenizer
6
  import torch
7
- from typing import List
8
  import time
9
 
10
  app = FastAPI()
@@ -14,11 +13,14 @@ API_KEYS = {
14
  "your-secret-api-key": "user1" # In production, use a secure database
15
  }
16
 
17
- # Load model and tokenizer globally
18
  MODEL_NAME = "tuner007/pegasus_paraphrase"
19
- tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME)
20
- model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
22
 
23
  class TextRequest(BaseModel):
24
  text: str
@@ -30,16 +32,6 @@ class BatchRequest(BaseModel):
30
  style: Optional[str] = "standard"
31
  num_variations: Optional[int] = 1
32
 
33
- def get_paraphrase_params(style: str):
34
- """Get model parameters based on style"""
35
- params = {
36
- "standard": {"temperature": 1.0, "top_k": 50, "top_p": 0.95},
37
- "formal": {"temperature": 0.7, "top_k": 30, "top_p": 0.9},
38
- "casual": {"temperature": 1.3, "top_k": 100, "top_p": 0.95},
39
- "creative": {"temperature": 1.5, "top_k": 120, "top_p": 0.99},
40
- }
41
- return params.get(style, params["standard"])
42
-
43
  async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
44
  if api_key not in API_KEYS:
45
  raise HTTPException(status_code=403, detail="Invalid API key")
@@ -48,7 +40,12 @@ async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
48
  def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
49
  try:
50
  # Get parameters based on style
51
- params = get_paraphrase_params(style)
 
 
 
 
 
52
 
53
  # Tokenize the input text
54
  inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
@@ -62,7 +59,6 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
62
  num_beams=num_variations * 2,
63
  temperature=params["temperature"],
64
  top_k=params["top_k"],
65
- top_p=params["top_p"],
66
  do_sample=True,
67
  early_stopping=True,
68
  )
@@ -78,6 +74,10 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
78
  except Exception as e:
79
  raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
80
 
 
 
 
 
81
  @app.post("/api/paraphrase")
82
  async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
83
  try:
@@ -133,17 +133,4 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
133
  }
134
 
135
  except Exception as e:
136
- raise HTTPException(status_code=500, detail=str(e))
137
-
138
- @app.get("/api/health")
139
- async def health_check(api_key: str = Depends(verify_api_key)):
140
- return {
141
- "status": "healthy",
142
- "model": MODEL_NAME,
143
- "device": device,
144
- "timestamp": datetime.now().isoformat()
145
- }
146
-
147
- if __name__ == "__main__":
148
- import uvicorn
149
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  from pydantic import BaseModel
3
  from typing import Optional, List
4
  from datetime import datetime
 
5
  import torch
6
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
7
  import time
8
 
9
  app = FastAPI()
 
13
  "your-secret-api-key": "user1" # In production, use a secure database
14
  }
15
 
16
+ # Initialize model and tokenizer with smaller model for Spaces
17
  MODEL_NAME = "tuner007/pegasus_paraphrase"
18
+ print("Loading model and tokenizer...")
19
+ tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
20
+ model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
21
+ device = "cpu" # Force CPU for Spaces deployment
22
+ model = model.to(device)
23
+ print("Model and tokenizer loaded successfully!")
24
 
25
  class TextRequest(BaseModel):
26
  text: str
 
32
  style: Optional[str] = "standard"
33
  num_variations: Optional[int] = 1
34
 
 
 
 
 
 
 
 
 
 
 
35
  async def verify_api_key(api_key: str = Header(..., name="X-API-Key")):
36
  if api_key not in API_KEYS:
37
  raise HTTPException(status_code=403, detail="Invalid API key")
 
40
  def generate_paraphrase(text: str, style: str = "standard", num_variations: int = 1) -> List[str]:
41
  try:
42
  # Get parameters based on style
43
+ params = {
44
+ "standard": {"temperature": 1.0, "top_k": 50},
45
+ "formal": {"temperature": 0.7, "top_k": 30},
46
+ "casual": {"temperature": 1.3, "top_k": 100},
47
+ "creative": {"temperature": 1.5, "top_k": 120},
48
+ }.get(style, {"temperature": 1.0, "top_k": 50})
49
 
50
  # Tokenize the input text
51
  inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
 
59
  num_beams=num_variations * 2,
60
  temperature=params["temperature"],
61
  top_k=params["top_k"],
 
62
  do_sample=True,
63
  early_stopping=True,
64
  )
 
74
  except Exception as e:
75
  raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
76
 
77
+ @app.get("/")
78
+ async def root():
79
+ return {"message": "Paraphrase API is running. Use /docs for API documentation."}
80
+
81
  @app.post("/api/paraphrase")
82
  async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key)):
83
  try:
 
133
  }
134
 
135
  except Exception as e:
136
+ raise HTTPException(status_code=500, detail=str(e))