eligapris committed on
Commit 5cf5eb9 · verified · 1 Parent(s): deb8249

Upload 3 files

Files changed (3)
  1. Dockerfile +19 -0
  2. main.py +55 -0
  3. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
+ FROM python:3.10-slim
+
+ WORKDIR /app
+
+ # Install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Create cache directory with proper permissions
+ RUN mkdir -p /app/model_cache && chmod 777 /app/model_cache
+
+ # Copy application code
+ COPY . .
+
+ # Expose the port the FastAPI app will run on
+ EXPOSE 7860
+
+ # Command to run the application
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
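For reference, building and serving this image locally might look like the following (the image tag bart-summarizer is illustrative, not part of the commit):

# Build the image from the repo root, then map the exposed port to the host
docker build -t bart-summarizer .
docker run -p 7860:7860 bart-summarizer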
main.py ADDED
@@ -0,0 +1,55 @@
+ # Simple summarization service using the BART model
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import BartTokenizer, BartForConditionalGeneration
+
+ app = FastAPI()
+
+ # Define request model
+ class SummarizationRequest(BaseModel):
+     text: str
+     max_length: int = 150
+     min_length: int = 40
+
+ # Download and cache the model during initialization
+ # This happens only once when the app starts
+ try:
+     # Explicitly download to a specific directory with proper error handling
+     cache_dir = "./model_cache"
+     model_name = "facebook/bart-large-cnn"
+
+     print(f"Loading tokenizer from {model_name}...")
+     tokenizer = BartTokenizer.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
+
+     print(f"Loading model from {model_name}...")
+     model = BartForConditionalGeneration.from_pretrained(model_name, cache_dir=cache_dir, local_files_only=False)
+
+     print("Model and tokenizer loaded successfully!")
+ except Exception as e:
+     print(f"Error loading model: {str(e)}")
+     raise
+
+ @app.post("/summarize/")
+ async def summarize_text(request: SummarizationRequest):
+     # Tokenize the input, truncating to the model's 1024-token limit
+     inputs = tokenizer(request.text, return_tensors="pt", max_length=1024, truncation=True)
+
+     # Generate summary
+     summary_ids = model.generate(
+         inputs["input_ids"],
+         max_length=request.max_length,
+         min_length=request.min_length,
+         num_beams=4,
+         length_penalty=2.0,
+         early_stopping=True
+     )
+
+     # Decode the generated summary
+     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+     return {"summary": summary}
+
+ # Basic health check endpoint
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "model": "facebook/bart-large-cnn"}
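With the server running, the two endpoints can be exercised as follows (host, port, and the sample payload are illustrative):

# POST a JSON body matching SummarizationRequest; returns {"summary": ...}
curl -X POST http://localhost:7860/summarize/ \
  -H "Content-Type: application/json" \
  -d '{"text": "Long article text goes here...", "max_length": 120, "min_length": 30}'

# Liveness probe; returns {"status": "healthy", "model": "facebook/bart-large-cnn"}
curl http://localhost:7860/health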
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ fastapi>=0.95.0
+ uvicorn>=0.21.1
+ transformers>=4.27.0
+ torch>=2.0.0
+ pydantic>=1.10.7
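For local development outside the container, the same dependencies can be installed and the app served directly (a sketch, assuming Python 3.10 to match the Dockerfile base image):

# Install the pinned minimum versions, then start the FastAPI app with uvicorn
pip install -r requirements.txt
uvicorn main:app --host 0.0.0.0 --port 7860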