lahiruchamika27 committed on
Commit
01807b0
·
verified ·
1 Parent(s): 488ca19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -71
app.py CHANGED
@@ -3,14 +3,8 @@ from pydantic import BaseModel
3
  from typing import Optional, List
4
  from datetime import datetime
5
  import torch
6
- from transformers import T5ForConditionalGeneration, T5Tokenizer
7
  import time
8
- import traceback
9
- import logging
10
-
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO)
13
- logger = logging.getLogger(__name__)
14
 
15
  app = FastAPI()
16
 
@@ -19,19 +13,14 @@ API_KEYS = {
19
  "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1" # In production, use a secure database
20
  }
21
 
22
- # Initialize model and tokenizer - using a dedicated T5 paraphrasing model
23
- MODEL_NAME = "Vamsi/T5_Paraphrase_Paws" # Specifically fine-tuned for paraphrasing
24
- try:
25
- print("Loading model and tokenizer...")
26
- tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
27
- model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
28
- device = "cuda" if torch.cuda.is_available() else "cpu"
29
- model = model.to(device)
30
- print(f"Model and tokenizer loaded successfully on {device}!")
31
- except Exception as e:
32
- error_msg = f"Error loading model: {str(e)}\n{traceback.format_exc()}"
33
- print(error_msg)
34
- logger.error(error_msg)
35
 
36
  class TextRequest(BaseModel):
37
  text: str
@@ -52,36 +41,23 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
52
  try:
53
  # Get parameters based on style
54
  params = {
55
- "standard": {"temperature": 1.0, "top_p": 0.9, "top_k": 50},
56
- "formal": {"temperature": 0.7, "top_p": 0.85, "top_k": 40},
57
- "casual": {"temperature": 1.2, "top_p": 0.95, "top_k": 60},
58
- "creative": {"temperature": 1.5, "top_p": 0.98, "top_k": 80},
59
- }.get(style, {"temperature": 1.0, "top_p": 0.9, "top_k": 50})
60
-
61
- # T5 models require a specific text format for tasks
62
- text_to_paraphrase = f"paraphrase: {text} </s>"
63
 
64
  # Tokenize the input text
65
- encoding = tokenizer.encode_plus(
66
- text_to_paraphrase,
67
- padding="longest",
68
- max_length=256,
69
- truncation=True,
70
- return_tensors="pt"
71
- )
72
- input_ids = encoding["input_ids"].to(device)
73
- attention_mask = encoding["attention_mask"].to(device)
74
 
75
  # Generate paraphrases
76
  with torch.no_grad():
77
  outputs = model.generate(
78
- input_ids=input_ids,
79
- attention_mask=attention_mask,
80
- max_length=256,
81
  num_return_sequences=num_variations,
82
  num_beams=num_variations * 2,
83
  temperature=params["temperature"],
84
- top_p=params["top_p"],
85
  top_k=params["top_k"],
86
  do_sample=True,
87
  early_stopping=True,
@@ -89,16 +65,14 @@ def generate_paraphrase(text: str, style: str = "standard", num_variations: int
89
 
90
  # Decode the generated outputs
91
  paraphrases = [
92
- tokenizer.decode(output, skip_special_tokens=True)
93
  for output in outputs
94
  ]
95
 
96
  return paraphrases
97
 
98
  except Exception as e:
99
- error_msg = f"Paraphrase generation error: {str(e)}\n{traceback.format_exc()}"
100
- logger.error(error_msg)
101
- raise HTTPException(status_code=500, detail=error_msg)
102
 
103
  @app.get("/")
104
  async def root():
@@ -127,9 +101,7 @@ async def paraphrase(request: TextRequest, api_key: str = Depends(verify_api_key
127
  }
128
 
129
  except Exception as e:
130
- error_msg = f"API error: {str(e)}"
131
- logger.error(f"{error_msg}\n{traceback.format_exc()}")
132
- raise HTTPException(status_code=500, detail=error_msg)
133
 
134
  @app.post("/api/batch-paraphrase")
135
  async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
@@ -161,26 +133,4 @@ async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_
161
  }
162
 
163
  except Exception as e:
164
- error_msg = f"API error: {str(e)}"
165
- logger.error(f"{error_msg}\n{traceback.format_exc()}")
166
- raise HTTPException(status_code=500, detail=error_msg)
167
-
168
- # For testing/debugging the API
169
- @app.get("/api/test")
170
- async def test_endpoint():
171
- try:
172
- test_text = "The quick brown fox jumps over the lazy dog."
173
- result = generate_paraphrase(test_text, "standard", 1)
174
- return {
175
- "status": "success",
176
- "test_text": test_text,
177
- "paraphrased": result,
178
- "model": MODEL_NAME,
179
- "device": device
180
- }
181
- except Exception as e:
182
- return {
183
- "status": "error",
184
- "error": str(e),
185
- "traceback": traceback.format_exc()
186
- }
 
3
  from typing import Optional, List
4
  from datetime import datetime
5
  import torch
6
+ from transformers import PegasusForConditionalGeneration, PegasusTokenizer
7
  import time
 
 
 
 
 
 
8
 
9
  app = FastAPI()
10
 
 
13
  "bdLFqk4IcYmRE2ONZeCts4DWrqkpqQxW": "user1" # In production, use a secure database
14
  }
15
 
16
+ # Initialize model and tokenizer with smaller model for Spaces
17
+ MODEL_NAME = "tuner007/pegasus_paraphrase"
18
+ print("Loading model and tokenizer...")
19
+ tokenizer = PegasusTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")
20
+ model = PegasusForConditionalGeneration.from_pretrained(MODEL_NAME, cache_dir="model_cache")
21
+ device = "cpu" # Force CPU for Spaces deployment
22
+ model = model.to(device)
23
+ print("Model and tokenizer loaded successfully!")
 
 
 
 
 
24
 
25
  class TextRequest(BaseModel):
26
  text: str
 
41
  try:
42
  # Get parameters based on style
43
  params = {
44
+ "standard": {"temperature": 1.5, "top_k": 80},
45
+ "formal": {"temperature": 1.0, "top_k": 50},
46
+ "casual": {"temperature": 1.6, "top_k": 100},
47
+ "creative": {"temperature": 2.8, "top_k": 170},
48
+ }.get(style, {"temperature": 1.0, "top_k": 50})
 
 
 
49
 
50
  # Tokenize the input text
51
+ inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
52
 
53
  # Generate paraphrases
54
  with torch.no_grad():
55
  outputs = model.generate(
56
+ **inputs,
57
+ max_length=200,
 
58
  num_return_sequences=num_variations,
59
  num_beams=num_variations * 2,
60
  temperature=params["temperature"],
 
61
  top_k=params["top_k"],
62
  do_sample=True,
63
  early_stopping=True,
 
65
 
66
  # Decode the generated outputs
67
  paraphrases = [
68
+ tokenizer.decode(output, skip_special_tokens=True)
69
  for output in outputs
70
  ]
71
 
72
  return paraphrases
73
 
74
  except Exception as e:
75
+ raise HTTPException(status_code=500, detail=f"Paraphrase generation error: {str(e)}")
 
 
76
 
77
  @app.get("/")
78
  async def root():
 
101
  }
102
 
103
  except Exception as e:
104
+ raise HTTPException(status_code=500, detail=str(e))
 
 
105
 
106
  @app.post("/api/batch-paraphrase")
107
  async def batch_paraphrase(request: BatchRequest, api_key: str = Depends(verify_api_key)):
 
133
  }
134
 
135
  except Exception as e:
136
+ raise HTTPException(status_code=500, detail=str(e))