"""FastAPI app that classifies uploaded images with a Hugging Face model.

Routes:
    GET  /          -- renders ``templates/index.html``.
    POST /predict/  -- accepts one uploaded image file and returns
                       ``{"prediction": <label>}`` on success, or
                       ``{"error": <message>}`` with status 400 (the upload is
                       not a decodable image) or 500 (inference failed).
"""

import io
import logging

import torch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image, UnidentifiedImageError
from transformers import AutoImageProcessor, AutoModelForImageClassification

logger = logging.getLogger(__name__)

# Hugging Face Hub identifier of the image-classification checkpoint served here.
MODEL_ID = "aashituli/promblemo"

app = FastAPI()

# Load model and processor once at import time so every request reuses them.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# Mount templates and static files
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Pillow failures that mean "the client did not send a usable image" (4xx),
# as opposed to failures inside the model (5xx). UnidentifiedImageError is an
# OSError subclass; plain OSError covers truncated/corrupt data surfacing
# during the eager decode in ``convert``.
_INVALID_IMAGE_ERRORS = (
    UnidentifiedImageError,
    OSError,
    Image.DecompressionBombError,
)


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page from ``templates/index.html``."""
    return templates.TemplateResponse("index.html", {"request": request})


def _decode_image(contents: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        UnidentifiedImageError: the bytes are not a recognised image format
            (this includes an empty upload).
        OSError: the image data is truncated or otherwise corrupt.
        Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    # ``convert`` forces a full decode, so corrupt data fails here rather
    # than later inside the processor.
    return Image.open(io.BytesIO(contents)).convert("RGB")


def _classify(image: Image.Image) -> str:
    """Run the model on one RGB image and return its top-1 label string."""
    inputs = processor(images=image, return_tensors="pt")
    # Grad mode is thread-local, so no_grad must wrap the forward pass in the
    # same (worker) thread that executes it.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: the uploaded image (any format Pillow can open).

    Returns:
        JSONResponse ``{"prediction": label}`` on success;
        ``{"error": ...}`` with status 400 when the upload is not a valid
        image, or status 500 when preprocessing/inference fails.
    """
    contents = await file.read()

    # Decoding and inference are CPU-bound and synchronous; run them in the
    # threadpool so they don't block the event loop for other requests.
    try:
        image = await run_in_threadpool(_decode_image, contents)
    except _INVALID_IMAGE_ERRORS:
        logger.info("Rejected undecodable upload %r", file.filename, exc_info=True)
        return JSONResponse(
            content={"error": "Uploaded file is not a valid image"},
            status_code=400,
        )

    try:
        predicted_class = await run_in_threadpool(_classify, image)
    except Exception:
        # Top-level request boundary: log the full traceback server-side and
        # return a generic message instead of leaking internals to the client.
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(content={"error": "Prediction failed"}, status_code=500)

    return JSONResponse(content={"prediction": predicted_class})