Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import List | |
import os | |
import requests | |
from tempfile import NamedTemporaryFile | |
from scraibe import Scraibe | |
app = FastAPI()

# Speech-recognition backend settings passed to Scraibe below.
WHISPER_MODEL_NAME = "large-v3"
WHISPER_TYPE = "whisper"

# A single Scraibe instance is constructed when this module is imported and is
# shared by every request that uses ``scraibe_model``.
scraibe_model = Scraibe(whisper_type=WHISPER_TYPE, whisper_model=WHISPER_MODEL_NAME)
class TranscriptionRequest(BaseModel):
    """Request body accepted by the transcription endpoint."""

    # Links to the audio files to transcribe (e.g. S3 object URLs).
    audio_links: List[str]
def _stream_to_temp_file(response, suffix: str = ".wav") -> str:
    """Write the body of a streamed ``requests`` response to a new temporary file.

    The file is created with ``delete=False`` so it outlives this call; the caller
    owns it and must delete it. If writing fails part-way, the partial file is
    removed before the exception propagates.

    Args:
        response: A ``requests.Response`` obtained with ``stream=True``.
        suffix (str): File-name suffix for the temporary file.
    Returns:
        str: Path of the written temporary file.
    """
    with NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
        try:
            for chunk in response.iter_content(chunk_size=8192):
                temp_file.write(chunk)
        except BaseException:
            # Close first so the unlink also succeeds on platforms that refuse to
            # delete open files, then drop the partial download.
            temp_file.close()
            try:
                os.remove(temp_file.name)
            except OSError:
                pass  # best-effort cleanup; the original error is re-raised below
            raise
    return temp_file.name


def download_audio_from_s3(s3_url: str, *, timeout: float = 30.0) -> str:
    """
    Download an audio file from an S3 URL and save it locally.

    The response body is streamed to a temporary ``.wav`` file in 8 KiB chunks.
    The HTTP connection and the file handle are both closed before returning, and
    a partially written file is deleted if the download fails.

    Args:
        s3_url (str): The S3 URL of the audio file.
        timeout (float): Seconds to wait when connecting and between received
            bytes. Without it a stalled server would block the worker forever.
    Returns:
        str: Path to the downloaded audio file. The caller must delete it.
    Raises:
        HTTPException: status 400 when the request fails (connection error,
            timeout, non-2xx status, or an error while streaming the body).
    """
    # NOTE(review): s3_url comes straight from the client, so the server will fetch
    # any http(s) URL it is given (SSRF risk, e.g. internal metadata endpoints).
    # Consider restricting it to an allow-list of hosts/buckets.
    try:
        with requests.get(s3_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            temp_path = _stream_to_temp_file(response)
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Failed to download file from S3: {str(e)}") from e
    print(f"Downloaded audio file to {temp_path}")
    return temp_path
@app.post("/transcribe")
async def transcribe_audio(request: TranscriptionRequest):
    """
    Endpoint to transcribe audio files from S3 links.

    Each link is downloaded to a temporary file and transcribed with the shared
    ``scraibe_model``. The temporary file is deleted whether or not transcription
    succeeds.

    Args:
        request (TranscriptionRequest): Input data containing the S3 audio links.
    Returns:
        dict: Mapping of each link to its ``autotranscribe`` result. A link that
        appears more than once is processed each time, and the last result wins.
    Raises:
        HTTPException: the 400 raised by ``download_audio_from_s3`` for a failed
            download is passed through unchanged; any other error becomes a 500.
    """
    # NOTE(review): the download and ``autotranscribe`` are synchronous calls made
    # directly inside this ``async def``, so they block the event loop (and every
    # other request) until they return. Consider offloading them with
    # ``fastapi.concurrency.run_in_threadpool`` once it is confirmed that
    # ``scraibe_model`` tolerates concurrent calls.
    results = {}
    try:
        for s3_link in request.audio_links:
            audio_path = download_audio_from_s3(s3_link)
            try:
                results[s3_link] = scraibe_model.autotranscribe(audio_path)
            finally:
                # Remove the temp file even when transcription raises.
                os.remove(audio_path)
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 400 for an unreachable link) instead
        # of re-wrapping them as a generic 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return results
@app.get("/")
async def root():
    """Return a welcome message that points clients at the ``/transcribe`` endpoint."""
    # NOTE(review): "Stem Extractor API" does not describe this service (audio
    # transcription); it is likely left over from another project. Confirm the name.
    return {"message": "Welcome to the Stem Extractor API. Use the /transcribe endpoint."}
if __name__ == "__main__":
    from uvicorn import run as run_server

    # Serve the app directly on all network interfaces, port 7860.
    run_server(app, port=7860, host="0.0.0.0")