import logging
import os
import tempfile

import torch
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda, Resize
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification

logging.basicConfig(level=logging.INFO)
# Initialize the FastAPI app instance
app = FastAPI()

# Load the pre-trained model and image processor during startup
model_path = './checkpoint-450'
trained_model = VideoMAEForVideoClassification.from_pretrained(model_path)
image_processor = VideoMAEImageProcessor.from_pretrained(model_path)
trained_model.eval()
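# Optional sanity check (an added sketch, not part of the original app): log
# the label set bundled with the checkpoint so a wrong model_path fails loudly.
logging.info("Loaded labels: %s", trained_model.config.id2label)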
# Set up the transform for validation/inference
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
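# Worked example: VideoMAE checkpoints typically use 16 frames per clip, so
# with the values above clip_duration = 16 * 4 / 30 ≈ 2.13 seconds of video.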
val_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    Resize(resize_to),
                ]
            ),
        ),
    ]
)
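# A hedged sanity check (not in the original): the transform expects
# {"video": uint8 tensor of shape (C, T, H, W)} and returns the clip
# subsampled to num_frames_to_sample, scaled to [0, 1], normalized, and
# resized, e.g.:
#   dummy = {"video": torch.randint(0, 255, (3, 64, 240, 320), dtype=torch.uint8)}
#   val_transform(dummy)["video"].shape  # -> (3, num_frames_to_sample, *resize_to)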
# Define the inference function
# @app.on_event("startup")
# def load_model():
#     global trained_model
#     trained_model = trained_model.to(trained_model.device)
def run_inference(video):
    """Utility to run inference given a test video.

    The video is assumed to be preprocessed already.
    """
    try:
        # (num_frames, num_channels, height, width)
        permuted_sample_test_video = video.permute(1, 0, 2, 3)
        inputs = {
            "pixel_values": permuted_sample_test_video.unsqueeze(0),
        }
        # Forward pass
        with torch.no_grad():
            outputs = trained_model(**inputs)
            logits = outputs.logits
        return logits
    except Exception as e:
        logging.error(f"Error during inference: {e}")
        return None
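# Hedged usage sketch: run_inference takes a preprocessed clip shaped
# (num_channels, num_frames, height, width), e.g. val_transform's "video"
# output, and returns logits of shape (1, num_labels):
#   clip = torch.randn(3, num_frames_to_sample, *resize_to)
#   run_inference(clip)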
# Define the API endpoints ("/" and "/predict" are assumed paths; the
# original functions had no route decorators)
@app.get("/")
async def roottransfer():
    return RedirectResponse(url="/docs")

@app.post("/predict")
async def predict(video: UploadFile = File(...)):
    # Persist the upload to a temporary file so pytorchvideo can decode it
    video_bytes = await video.read()
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(video_bytes)
        tmp_path = tmp.name
    try:
        # Decode one clip of the trained duration, then preprocess it
        video_data = EncodedVideo.from_path(tmp_path).get_clip(0, clip_duration)
        sample_test_video = val_transform(video_data)
    finally:
        os.remove(tmp_path)
    # Run inference and map the top logit to its class label
    logits = run_inference(sample_test_video["video"])
    if logits is None:
        return {"error": "inference failed; see server logs"}
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = trained_model.config.id2label[predicted_class_idx]
    return {"predicted_class": predicted_label}
if __name__ == "__main__":
    uvicorn.run(app, port=8000)
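# Example request (assuming the server is running locally and "/predict" is
# the route; sample.mp4 is a placeholder file name):
#   curl -X POST "http://127.0.0.1:8000/predict" -F "video=@sample.mp4"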