Spaces:
Sleeping
Sleeping
File size: 2,474 Bytes
e11256b ea89d31 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import os
import threading
from tempfile import NamedTemporaryFile
from typing import List

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from scraibe import Scraibe
app = FastAPI()

# Initialize the Scraibe model once, at import time, so every request reuses the
# same instance. Both settings can be overridden through environment variables
# without a code change; the defaults are the values this service was built with.
WHISPER_MODEL_NAME = os.environ.get("WHISPER_MODEL_NAME", "large-v3")
WHISPER_TYPE = os.environ.get("WHISPER_TYPE", "whisper")
scraibe_model = Scraibe(
    whisper_model=WHISPER_MODEL_NAME,
    whisper_type=WHISPER_TYPE,
)
class TranscriptionRequest(BaseModel):
    """JSON body accepted by ``POST /transcribe``."""

    # URLs of the audio files to transcribe, one transcription per entry.
    audio_links: List[str]
def _discard_partial_download(path) -> None:
    """Best-effort removal of a temp file; no-op when ``path`` is None or already gone."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def download_audio_from_s3(s3_url: str, timeout: float = 60.0) -> str:
    """
    Download an audio file from a URL (e.g. an S3 object URL) and save it locally.

    The response body is streamed to a named temporary file in 8 KiB chunks, so
    the whole file is never held in memory. The caller owns the returned file and
    must delete it when done. If anything fails, the partially written file is
    removed before the error propagates.

    NOTE(review): the URL comes straight from the client request, so the server
    will fetch any http(s) address it is given (SSRF exposure), and the download
    size is unbounded. Consider a host allow-list and a byte cap.

    Args:
        s3_url (str): URL of the audio file.
        timeout (float): Seconds to wait for the connection and for each read
            from the server before giving up. Without a timeout, a stalled
            server would hang the request indefinitely.

    Returns:
        str: Path to the downloaded audio file. It always has a ``.wav`` suffix,
        whatever the actual format.

    Raises:
        HTTPException: 400 if the request fails, times out, or returns an HTTP
            error status.
    """
    temp_path = None
    try:
        # Using the response as a context manager releases the connection even
        # though stream=True defers reading the body.
        with requests.get(s3_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # delete=False: the file must outlive this handle so the model can
            # open it by path. Writing through this same handle avoids leaking a
            # second open descriptor (and reopen failures on Windows).
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
    except requests.exceptions.RequestException as e:
        _discard_partial_download(temp_path)
        raise HTTPException(status_code=400, detail=f"Failed to download file from S3: {str(e)}") from e
    except Exception:
        # e.g. disk full while writing: don't leave a truncated file behind.
        _discard_partial_download(temp_path)
        raise

    print(f"Downloaded audio file to {temp_path}")
    return temp_path
# Inference runs in worker threads (see transcribe_audio), so concurrent requests
# could otherwise call the shared model at the same time. Serialize model access.
# NOTE(review): remove this lock only if Scraibe is confirmed thread-safe.
_MODEL_LOCK = threading.Lock()


def _transcribe_link(s3_link: str):
    """Download one audio link, transcribe it, and always delete the temp file.

    Synchronous and blocking; intended to run in a worker thread.
    """
    audio_path = download_audio_from_s3(s3_link)
    try:
        with _MODEL_LOCK:
            return scraibe_model.autotranscribe(audio_path)
    finally:
        # Clean up even when transcription raises. Removal is best-effort, so
        # a cleanup hiccup does not discard a successful transcription.
        try:
            os.remove(audio_path)
        except OSError:
            pass


@app.post("/transcribe")
async def transcribe_audio(request: TranscriptionRequest):
    """
    Endpoint to transcribe audio files from S3 links.

    Links are processed one after another. The blocking work (HTTP download and
    model inference) runs in FastAPI's thread pool, so the event loop remains
    free to serve other requests (e.g. ``GET /``) in the meantime.

    Args:
        request (TranscriptionRequest): Input data containing the audio links.

    Returns:
        dict: Maps each input link to the value returned by
        ``scraibe_model.autotranscribe`` for it. Duplicate links collapse to a
        single key. NOTE(review): FastAPI must be able to JSON-encode that
        value; confirm the Scraibe return type serializes as intended.

    Raises:
        HTTPException: 400 if a download fails (propagated unchanged from
            ``download_audio_from_s3``); 500 for any other error.
    """
    results = {}
    try:
        for s3_link in request.audio_links:
            results[s3_link] = await run_in_threadpool(_transcribe_link, s3_link)
    except HTTPException:
        # Already carries the right status (e.g. 400 for a bad download).
        # Re-raise as-is instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return results
@app.get("/")
async def root():
    """Landing endpoint: greet the client and point it at ``/transcribe``."""
    return {"message": "Welcome to the Transcription API. Use the /transcribe endpoint."}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on every
    # interface (0.0.0.0), port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)
|