Stem-Extractor / app.py
samarth-ht's picture
Update app.py
ea89d31 verified
import os
import threading
from tempfile import NamedTemporaryFile
from typing import List

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from scraibe import Scraibe
# Whisper checkpoint and backend passed to Scraibe below.
WHISPER_MODEL_NAME = "large-v3"
WHISPER_TYPE = "whisper"

# ASGI application; the route decorators further down register on it.
app = FastAPI()

# Built once at import time (module level), so every request handled by this
# process reuses the same instance. Presumably this loads the model weights,
# which may make startup slow / memory-heavy — confirm against Scraibe docs.
scraibe_model = Scraibe(whisper_model=WHISPER_MODEL_NAME, whisper_type=WHISPER_TYPE)
class TranscriptionRequest(BaseModel):
    """Request body accepted by the POST /transcribe endpoint."""

    # URLs of the audio files to download and transcribe, in processing order.
    # Presumably S3 (e.g. pre-signed) links; they are fetched with a plain GET.
    audio_links: List[str]
def _save_stream_to_temp(response, suffix: str = ".wav") -> str:
    """
    Write the body of a streaming HTTP response to a new temporary file.
    The file is created with delete=False so it survives being closed; the
    caller owns it and must remove it. If writing fails part-way, the partial
    file is deleted before the exception propagates.
    Args:
        response: A ``requests`` response opened with ``stream=True``.
        suffix (str): File-name suffix for the temporary file.
    Returns:
        str: Path to the written file.
    """
    temp_file = NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Write through the handle we already hold (and close it on exit)
        # instead of re-opening the path, which leaked this descriptor.
        with temp_file:
            for chunk in response.iter_content(chunk_size=8192):
                temp_file.write(chunk)
    except BaseException:
        # Don't leave a truncated file behind if the transfer dies mid-stream.
        try:
            os.remove(temp_file.name)
        except OSError:
            pass  # best-effort cleanup; keep the original error propagating
        raise
    return temp_file.name


def download_audio_from_s3(s3_url: str, timeout: float = 60.0) -> str:
    """
    Download an audio file from an S3 URL and save it locally.
    The URL is fetched with a plain unauthenticated GET. The caller owns the
    returned file and is responsible for deleting it.
    Args:
        s3_url (str): The S3 URL of the audio file.
        timeout (float): Seconds to wait for the connection and between bytes
            received from the server (not a limit on the total download time).
    Returns:
        str: Path to the downloaded audio file. It always gets a ".wav"
            suffix, whatever the actual format is.
    Raises:
        HTTPException: status 400 if the request fails, times out, or the
            server answers with an HTTP error status.
    """
    # NOTE(review): the URL comes from the request body and the download size
    # is not capped, so clients can make this server fetch arbitrary URLs
    # (SSRF) or fill the disk. Consider a host allow-list and a max size.
    try:
        # The ``with`` closes the streamed response (releasing its connection)
        # even when raise_for_status() or the download raises.
        with requests.get(s3_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            audio_path = _save_stream_to_temp(response)
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Failed to download file from S3: {str(e)}") from e
    print(f"Downloaded audio file to {audio_path}")
    return audio_path
# Serializes calls into the shared module-level model. The work below runs in
# worker threads, so concurrent requests could otherwise call autotranscribe at
# the same time; nothing here establishes that the model is thread-safe or
# that several concurrent runs fit in memory.
_MODEL_LOCK = threading.Lock()


def _transcribe_link(s3_link: str):
    """
    Download one audio file, transcribe it, and delete the local copy.
    Blocking; meant to be run in a worker thread, not on the event loop.
    Args:
        s3_link (str): URL of the audio file.
    Returns:
        Whatever ``scraibe_model.autotranscribe`` returns for the file.
    Raises:
        HTTPException: from download_audio_from_s3 if the download fails.
    """
    audio_path = download_audio_from_s3(s3_link)
    try:
        with _MODEL_LOCK:
            return scraibe_model.autotranscribe(audio_path)
    finally:
        # Remove the file even when transcription raises, so failed requests
        # don't accumulate downloads in the temp directory.
        try:
            os.remove(audio_path)
        except OSError as err:
            print(f"Could not remove temporary file {audio_path}: {err}")


@app.post("/transcribe")
async def transcribe_audio(request: TranscriptionRequest):
    """
    Endpoint to transcribe audio files from S3 links.
    Links are processed one after another. Each download and transcription runs
    in FastAPI's threadpool so the event loop stays free for other requests.
    Args:
        request (TranscriptionRequest): Input data containing the S3 audio links.
    Returns:
        dict: Maps each input link to its transcription (a repeated link keeps
            the last result).
    Raises:
        HTTPException: 400 (propagated unchanged) if a download fails;
            500 wrapping any other error.
    """
    # NOTE(review): autotranscribe's return type isn't visible here; FastAPI
    # must be able to JSON-encode it — confirm.
    results = {}
    try:
        for s3_link in request.audio_links:
            results[s3_link] = await run_in_threadpool(_transcribe_link, s3_link)
    except HTTPException:
        # Already carries the correct status (e.g. 400 for a bad link); don't
        # let the generic handler below turn it into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return results
@app.get("/")
async def root():
    """Return a short welcome message that points clients to /transcribe."""
    welcome_text = "Welcome to the Stem Extractor API. Use the /transcribe endpoint."
    return {"message": welcome_text}
if __name__ == "__main__":
    # Running the file directly serves the app with uvicorn on all interfaces,
    # port 7860; importing the module (e.g. via an ASGI server) skips this.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=7860)