aashituli commited on
Commit
1db8d5a
·
verified ·
1 Parent(s): 8538457

Create APP

Browse files
Files changed (1) hide show
  1. APP +40 -0
APP ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile, Request
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.templating import Jinja2Templates
5
+ from PIL import Image
6
+ import torch
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+ import io
9
+
10
+ app = FastAPI()
11
+
12
+ # Load model and processor once
13
+ processor = AutoImageProcessor.from_pretrained("aashituli/promblemo")
14
+ model = AutoModelForImageClassification.from_pretrained("aashituli/promblemo")
15
+
16
+ # Mount templates and static files
17
+ app.mount("/static", StaticFiles(directory="static"), name="static")
18
+ templates = Jinja2Templates(directory="templates")
19
+
20
+ @app.get("/", response_class=HTMLResponse)
21
+ async def home(request: Request):
22
+ return templates.TemplateResponse("index.html", {"request": request})
23
+
24
+ @app.post("/predict/")
25
+ async def predict(file: UploadFile = File(...)):
26
+ try:
27
+ contents = await file.read()
28
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
29
+ inputs = processor(images=image, return_tensors="pt")
30
+
31
+ with torch.no_grad():
32
+ outputs = model(**inputs)
33
+
34
+ predicted_class_idx = outputs.logits.argmax(-1).item()
35
+ predicted_class = model.config.id2label[predicted_class_idx]
36
+
37
+ return JSONResponse(content={"prediction": predicted_class})
38
+
39
+ except Exception as e:
40
+ return JSONResponse(content={"error": str(e)}, status_code=500)