File size: 2,474 Bytes
e11256b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea89d31
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import threading
from tempfile import NamedTemporaryFile
from typing import List

import requests
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from scraibe import Scraibe

app = FastAPI()

# Whisper backend settings handed to the shared Scraibe pipeline below.
WHISPER_MODEL_NAME = "large-v3"
WHISPER_TYPE = "whisper"

# Constructed once at import time; every /transcribe request reuses this instance.
scraibe_model = Scraibe(whisper_model=WHISPER_MODEL_NAME, whisper_type=WHISPER_TYPE)

class TranscriptionRequest(BaseModel):
    """Request body accepted by ``POST /transcribe``."""

    # URLs of the audio files to download and transcribe.
    audio_links: List[str]


def _discard_file(path):
    """Best-effort delete of ``path`` (no-op for ``None``); errors are ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass  # the caller is already propagating the real failure


def download_audio_from_s3(s3_url: str, timeout: float = 60.0) -> str:
    """
    Download an audio file from an S3 URL and save it to a local temporary file.

    The temporary file is created with ``delete=False``, so the caller owns it
    and is responsible for removing it when finished. If the download fails,
    any partially written file is removed before the error is raised.

    Args:
        s3_url (str): The S3 URL of the audio file (fetched with ``requests``).
        timeout (float): Seconds to wait for the connection and for each read
            from the server before giving up. Defaults to 60.

    Returns:
        str: Path to the downloaded audio file. The name always ends in
        ``.wav``, regardless of the actual audio format.

    Raises:
        HTTPException: 400 if the request fails or returns an error status.
    """
    # NOTE(review): s3_url comes straight from the client, so the service will
    # fetch arbitrary URLs (SSRF risk). Consider an allow-list of bucket hosts
    # if this endpoint is publicly reachable.
    temp_path = None
    try:
        # `with` closes the streamed response and releases its connection.
        with requests.get(s3_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # Write through the temp file's own handle; reopening it by name
            # leaked the original descriptor (and fails on Windows).
            with NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)

    except requests.exceptions.RequestException as e:
        # Don't leave a half-downloaded file behind.
        _discard_file(temp_path)
        raise HTTPException(status_code=400, detail=f"Failed to download file from S3: {str(e)}") from e
    except Exception:
        # e.g. disk full while writing: clean up, then let the error propagate.
        _discard_file(temp_path)
        raise

    print(f"Downloaded audio file to {temp_path}")
    return temp_path

# Serializes use of the shared model. Work now runs in the threadpool instead
# of blocking the event loop, and the original handler implicitly processed
# one transcription at a time.
_model_lock = threading.Lock()


def _transcribe_link(s3_link: str):
    """Download one audio file, transcribe it, and always delete the local copy."""
    audio_path = download_audio_from_s3(s3_link)
    try:
        with _model_lock:
            return scraibe_model.autotranscribe(audio_path)
    finally:
        # Runs even when transcription raises, so temp files never pile up.
        os.remove(audio_path)


@app.post("/transcribe")
async def transcribe_audio(request: TranscriptionRequest):
    """
    Endpoint to transcribe audio files from S3 links.

    Links are processed in order. If any link fails, no partial results are
    returned and the error is raised instead.

    Args:
        request (TranscriptionRequest): Input data containing the S3 audio links.

    Returns:
        dict: Maps each link to the result of ``scraibe_model.autotranscribe``.
        A link listed twice is transcribed twice; the later result wins.

    Raises:
        HTTPException: 400 if a file cannot be downloaded; 500 for any other
            failure.
    """
    results = {}

    try:
        for s3_link in request.audio_links:
            # Download and transcription block; run them in the threadpool so
            # the event loop keeps serving other requests meanwhile.
            results[s3_link] = await run_in_threadpool(_transcribe_link, s3_link)
    except HTTPException:
        # Keep the status chosen by the downloader (400) instead of masking
        # it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return results

@app.get("/")
async def root():
    """Return a greeting that points clients at the /transcribe endpoint."""
    # NOTE(review): "Stem Extractor" looks like a leftover from another
    # project; this service does transcription. Confirm the intended name.
    welcome = "Welcome to the Stem Extractor API. Use the /transcribe endpoint."
    return {"message": welcome}


if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)