"""FastAPI service that classifies an uploaded video with a VideoMAE checkpoint."""

import logging
import os
import tempfile

import pytorchvideo.data  # noqa: F401  (kept from the original import list)
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile, responses  # noqa: F401
from fastapi.responses import RedirectResponse
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    UniformTemporalSubsample,
)
# Lambda and Resize are torchvision transforms (as in the upstream VideoMAE
# fine-tuning recipe), so they are imported from torchvision rather than
# from pytorchvideo.transforms.
from torchvision.transforms import Compose, Lambda, Resize
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

logging.basicConfig(level=logging.INFO)

# Initialize the FastAPI app instance
app = FastAPI()

# Load the pre-trained model and image processor once, at import time
model_path = './checkpoint-450'
trained_model = VideoMAEForVideoClassification.from_pretrained(model_path)
image_processor = VideoMAEImageProcessor.from_pretrained(model_path)
trained_model.eval()

# Set up the transform for validation/inference
mean = image_processor.image_mean
std = image_processor.image_std
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
# Length (in seconds) of the clip that is decoded from the start of the upload.
clip_duration = num_frames_to_sample * sample_rate / fps

val_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    Resize(resize_to),
                ]
            ),
        ),
    ]
)


def _decode_clip(video_bytes):
    """Decode the first ``clip_duration`` seconds of an encoded video.

    The bytes are spooled to a temporary file because
    ``EncodedVideo.from_path`` is the path-based entry point of pytorchvideo;
    the file is always removed afterwards.

    Args:
        video_bytes: Raw bytes of an encoded video container (e.g. MP4).

    Returns:
        The decoded frame tensor stored under the ``"video"`` key of the clip,
        or ``None`` when no frames could be decoded for that time range.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(video_bytes)
        encoded = EncodedVideo.from_path(tmp.name, decode_audio=False)
        try:
            clip = encoded.get_clip(start_sec=0, end_sec=clip_duration)
        finally:
            encoded.close()
    finally:
        os.unlink(tmp.name)
    return clip["video"] if clip else None


def run_inference(video):
    """Run the classifier on one preprocessed video tensor.

    Args:
        video: Preprocessed clip laid out as
            (num_channels, num_frames, height, width).

    Returns:
        The logits tensor produced by the model, or ``None`` if the forward
        pass raised (the error is logged with its traceback).
    """
    try:
        # The model wants (batch, num_frames, num_channels, height, width).
        permuted_video = video.permute(1, 0, 2, 3)
        inputs = {"pixel_values": permuted_video.unsqueeze(0)}

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = trained_model(**inputs)
        return outputs.logits
    except Exception as e:
        logging.exception("Error during inference: %s", e)
        return None


# Define the API endpoints
@app.get("/")
async def roottransfer():
    """Redirect the root URL to the interactive API docs."""
    return RedirectResponse(url="/docs")


@app.post("/predict")
async def predict(video: UploadFile = File(...)):
    """Classify an uploaded video file.

    Args:
        video: The uploaded, encoded video file.

    Returns:
        ``{"predicted_class": <label>}`` where the label comes from the
        model config's ``id2label`` mapping.

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded or
            preprocessed; 500 if the model forward pass fails.
    """
    # Read the video file
    video_bytes = await video.read()
    if not video_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decode and preprocess the video
    try:
        frames = _decode_clip(video_bytes)
        if frames is None:
            raise ValueError("no video frames decoded")
        sample_test_video = val_transform({"video": frames})
    except Exception as exc:
        logging.exception("Failed to decode or preprocess upload: %s", exc)
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded file as a video.",
        ) from exc

    # Run inference; run_inference signals failure by returning None.
    logits = run_inference(sample_test_video["video"])
    if logits is None:
        raise HTTPException(status_code=500, detail="Inference failed.")

    # Get the predicted class label
    predicted_class_idx = logits.argmax(-1).item()
    predicted_label = trained_model.config.id2label[predicted_class_idx]
    return {"predicted_class": predicted_label}


if __name__ == "__main__":
    # Imported here so the module can be imported (e.g. by an external ASGI
    # server) without uvicorn being required at import time.
    import uvicorn

    uvicorn.run(app, port=8000)