"""FastAPI service that downloads audio files from URLs and transcribes them with Scraibe."""

import os
from tempfile import NamedTemporaryFile
from typing import List

import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from scraibe import Scraibe

app = FastAPI()

# Initialize the Scraibe model with the specified parameters
WHISPER_MODEL_NAME = "large-v3"
WHISPER_TYPE = "whisper"

# (connect, read) timeouts in seconds so a stalled remote host cannot hang a
# request forever. The read timeout applies between received chunks, not to
# the whole download.
DOWNLOAD_TIMEOUT = (10, 60)

scraibe_model = Scraibe(
    whisper_model=WHISPER_MODEL_NAME,
    whisper_type=WHISPER_TYPE,
)


class TranscriptionRequest(BaseModel):
    """Request body for the /transcribe endpoint."""

    # URLs of the audio files to download and transcribe.
    audio_links: List[str]


def _remove_file_quietly(path):
    """Delete *path* if given; report (but do not raise) on failure.

    Cleanup is best-effort so that a failed delete never masks the result or
    the error of the operation that created the file.
    """
    if path is None:
        return
    try:
        os.remove(path)
    except OSError as e:
        print(f"Could not remove temporary file {path}: {e}")


def download_audio_from_s3(s3_url: str) -> str:
    """
    Download an audio file from an S3 URL and save it locally.

    The file is streamed in 8 KiB chunks into a temporary ``.wav`` file that is
    NOT deleted automatically; the caller owns the returned path and must
    remove it. If the download fails, any partially written file is removed.

    Args:
        s3_url (str): The S3 URL of the audio file.

    Returns:
        str: Path to the downloaded audio file.

    Raises:
        HTTPException: 400 if the request fails, times out, or returns an
            HTTP error status.
    """
    temp_path = None
    try:
        # The context manager guarantees the streamed connection is released.
        with requests.get(s3_url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()

            # Write through the handle NamedTemporaryFile already opened rather
            # than reopening by name: this avoids leaking the first handle and
            # works on platforms that forbid opening the file twice.
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
    except requests.exceptions.RequestException as e:
        _remove_file_quietly(temp_path)
        raise HTTPException(
            status_code=400,
            detail=f"Failed to download file from S3: {str(e)}",
        ) from e
    except Exception:
        # e.g. disk full while writing: do not leave a partial file behind.
        _remove_file_quietly(temp_path)
        raise

    print(f"Downloaded audio file to {temp_path}")
    return temp_path


@app.post("/transcribe")
async def transcribe_audio(request: TranscriptionRequest):
    """
    Endpoint to transcribe audio files from S3 links.

    Each link is downloaded to a temporary file, passed to
    ``scraibe_model.autotranscribe``, and the temporary file is removed
    afterwards whether or not transcription succeeded. Processing stops at the
    first link that fails.

    Args:
        request (TranscriptionRequest): Input data containing S3 audio links.

    Returns:
        dict: Maps each input link to the value returned by ``autotranscribe``.

    Raises:
        HTTPException: 400 if a download fails; 500 for any other error.
    """
    # NOTE(review): the download and the transcription are blocking calls
    # inside an ``async def`` handler, so the event loop is blocked while they
    # run. Left as-is here; consider a plain ``def`` handler or a thread pool.
    results = {}
    try:
        for s3_link in request.audio_links:
            # Download the audio file from the S3 link
            audio_path = download_audio_from_s3(s3_link)
            try:
                # Perform the transcription
                transcription = scraibe_model.autotranscribe(
                    audio_path,
                )
            finally:
                # Clean up the downloaded file even if transcription raised
                _remove_file_quietly(audio_path)

            # Collect the result
            results[s3_link] = transcription

        return results
    except HTTPException:
        # Preserve the status code chosen by the download helper (400)
        # instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def root():
    """Return a short welcome message pointing at the /transcribe endpoint."""
    return {"message": "Welcome to the Stem Extractor API. Use the /transcribe endpoint."}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)