# promblemo / APP
# aashituli's picture
# Create APP
# 1db8d5a verified
import io

import torch
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# The ASGI application object that the route decorators below attach to.
app = FastAPI()

# The image processor and classifier are created here, at import time, so the
# same two objects are reused by every request instead of being rebuilt.
processor = AutoImageProcessor.from_pretrained("aashituli/promblemo")
model = AutoModelForImageClassification.from_pretrained("aashituli/promblemo")

# HTML templates are looked up in ./templates; files in ./static are served
# under the /static URL prefix.
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render ``index.html`` as the landing page for ``GET /``."""
    # The incoming request is handed to the template under the "request" key.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or read the data (Pillow's
            ``UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    # ``convert`` returns a new, fully loaded image, so the source image can
    # be closed as soon as the conversion is done.
    with Image.open(io.BytesIO(contents)) as raw:
        return raw.convert("RGB")


def _classify_image(image: Image.Image) -> str:
    """Run the module-level classifier on *image* and return the top label."""
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: gradients are not needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Index of the highest logit along the last dimension; ``.item()``
    # requires that exactly one image was processed.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted label as JSON.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        ``{"prediction": <label>}`` on success;
        ``{"error": <message>}`` with status 400 when the upload cannot be
        decoded as an image, or with status 500 for any other failure.
    """
    try:
        contents = await file.read()
        try:
            # Decoding is synchronous CPU work; run it in the thread pool so
            # the event loop keeps serving other requests.
            image = await run_in_threadpool(_decode_image, contents)
        except OSError:
            # An unreadable or empty upload is a client error, not a server
            # fault, so report it as 400 rather than 500.
            return JSONResponse(
                content={"error": "Uploaded file is not a readable image."},
                status_code=400,
            )
        # Model inference is blocking as well; keep it off the event loop.
        predicted_class = await run_in_threadpool(_classify_image, image)
        return JSONResponse(content={"prediction": predicted_class})
    except Exception as e:
        # Top-level boundary: any remaining failure becomes a JSON 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)