root
remove preload
58ffba8
from fastapi import FastAPI, File, UploadFile, responses
import logging
logging.basicConfig(level=logging.INFO)
from fastapi.responses import RedirectResponse
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
import torch
import pytorchvideo.data
from pytorchvideo.transforms import (
ApplyTransformToKey,
Normalize,
UniformTemporalSubsample,
Lambda,
Resize,
)
from torchvision.transforms import Compose
import os
# Application object; the route handlers in this module register on it.
app = FastAPI()

# Checkpoint directory from which both the model and its image processor are loaded.
model_path = './checkpoint-450'

# Load once at import time and put the model into evaluation mode immediately.
trained_model = VideoMAEForVideoClassification.from_pretrained(model_path).eval()
image_processor = VideoMAEImageProcessor.from_pretrained(model_path)
# Preprocessing parameters, all derived from the loaded processor / model config.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's size dict either gives one "shortest_edge" value (used for both
# dimensions) or explicit "height" / "width" entries.
if "shortest_edge" not in image_processor.size:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Number of frames the model consumes per clip.
num_frames_to_sample = trained_model.config.num_frames

# Clip length computed as num_frames * sample_rate / fps; presumably seconds,
# with fps being an assumed frame rate (it is hard-coded, not read from any video).
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
# Inference-time preprocessing. Operates on a dict and only rewrites its "video"
# entry; any other keys are left for ApplyTransformToKey to pass along.
val_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            # Pick exactly as many frames as the model expects.
            UniformTemporalSubsample(num_frames_to_sample),
            # Divide by 255 (presumably maps 0-255 pixel values into [0, 1]).
            Lambda(lambda frames: frames / 255.0),
            # Standardize with the image processor's mean / std.
            Normalize(mean, std),
            # Resize to the (height, width) the image processor specifies.
            Resize(resize_to),
        ]),
    ),
])
# Define the inference function
# @app.on_event("startup")
# def load_model():
# global trained_model
# trained_model = trained_model.to(trained_model.device)
def run_inference(video):
    """Run the loaded model on a single preprocessed video tensor.

    Args:
        video: 4-D tensor whose first two axes are swapped before the forward
            pass; presumably (num_channels, num_frames, height, width) as
            produced by the transform pipeline -- confirm against the caller.

    Returns:
        The model's ``logits`` for a batch of one, or ``None`` when anything
        in the permute / forward path raises (the error is logged, not
        re-raised), so callers must check for ``None``.
    """
    try:
        # Swap the first two axes to get (num_frames, num_channels, height, width).
        frames_first = video.permute(1, 0, 2, 3)
        # Add a leading batch dimension of size one.
        batch = frames_first.unsqueeze(0)
        # Forward pass without tracking gradients.
        with torch.no_grad():
            model_output = trained_model(pixel_values=batch)
        logits = model_output.logits
    except Exception as e:
        logging.error(f"Error during inference: {e}")
        return None
    return logits
# Define the API endpoint
@app.get("/")
async def roottransfer():
    # Requests to the site root are redirected to the "/docs" page.
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.post("/predict")
async def predict(video: UploadFile = File(...)):
    """Classify an uploaded video and return the predicted label.

    The upload is written to a temporary file, decoded with PyTorchVideo, and
    the first ``clip_duration`` seconds are run through ``val_transform`` and
    the model.

    Args:
        video: The uploaded video file (multipart form field ``video``).

    Returns:
        ``{"predicted_class": <label>}`` where the label is looked up in the
        model config's ``id2label`` mapping.

    Raises:
        HTTPException: 400 if the upload is empty, cannot be decoded, or
            yields no frames; 500 if inference fails.
    """
    # Local imports keep this handler self-contained; fastapi and pytorchvideo
    # are already dependencies of this module, tempfile is standard library.
    import tempfile

    from fastapi import HTTPException
    from pytorchvideo.data.encoded_video import EncodedVideo

    video_bytes = await video.read()
    if not video_bytes:
        raise HTTPException(status_code=400, detail="Uploaded video is empty.")

    # EncodedVideo is opened from a path, so spill the upload to disk first.
    # The original extension is kept on the temporary file name.
    suffix = os.path.splitext(video.filename or "")[1]
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(video_bytes)
        encoded_video = EncodedVideo.from_path(tmp_path, decode_audio=False)
        # The transform pipeline works on a dict holding a decoded frame
        # tensor under "video", which is what get_clip returns.
        clip = encoded_video.get_clip(start_sec=0, end_sec=clip_duration)
    except Exception as e:
        logging.error(f"Error decoding uploaded video: {e}")
        raise HTTPException(
            status_code=400, detail="Could not decode the uploaded video."
        ) from e
    finally:
        # Always remove the temporary file, whether decoding worked or not.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

    if clip is None or clip.get("video") is None:
        raise HTTPException(
            status_code=400, detail="No video frames could be read from the upload."
        )

    # Preprocess the clip
    sample_test_video = val_transform(clip)

    # run_inference logs and returns None on failure; turn that into an
    # explicit server error instead of crashing on ``None.argmax``.
    logits = run_inference(sample_test_video["video"])
    if logits is None:
        raise HTTPException(status_code=500, detail="Inference failed.")

    # Map the highest-scoring class index to its label.
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = trained_model.config.id2label[predicted_class_idx]
    return {"predicted_class": predicted_label}
if __name__ == "__main__":
    # uvicorn is only needed when this file is launched directly as a script,
    # so it is imported here; without this import the call below raises NameError.
    import uvicorn

    uvicorn.run(app, port=8000)