|
import modal |
|
from fastapi import Form, HTTPException |
|
import hashlib |
|
import json |
|
from typing import Optional, Dict |
|
import gc |
|
|
|
|
|
|
|
# Coordinates of the nvidia/cuda base image used for the container.
cuda_version = "12.4.0"
flavor = "devel"
operating_sys = "ubuntu22.04"

# Registry tag assembled from the pieces above ("<version>-<flavor>-<os>").
tag = "-".join((cuda_version, flavor, operating_sys))
|
|
|
# Container image: CUDA base image + system libs (ffmpeg for decoding, cuDNN 8
# for ctranslate2) + the Python stack for the API and WhisperX.
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install(
        "git",
        "ffmpeg",
        "libcudnn8",
        "libcudnn8-dev",
    )
    .pip_install(
        "fastapi[standard]",
        "httpx",
        # Imported by Model.transcribe for the audio download; listed explicitly
        # instead of relying on it arriving as a transitive dependency.
        "requests",
        "torch==2.0.0",
        "torchaudio==2.0.0",
        "numpy<2.0",
        # Extra (not replacement) index so fastapi/httpx/requests still resolve from PyPI.
        extra_index_url="https://download.pytorch.org/whl/cu118",
    )
    .pip_install(
        # NOTE(review): this spec was mangled into "[email protected]" (email
        # obfuscation); restored to the v3.2.0 tag — confirm this is the intended release.
        "git+https://github.com/m-bain/whisperx.git@v3.2.0",
        "ffmpeg-python",
        "ctranslate2==4.4.0",
    )
)

app = modal.App("whisperx-api", image=image)
|
|
|
# GPU type requested for the transcription class.
GPU_CONFIG: str = "L4"

# Mount point of the persistent volume; Model.setup uses it as the WhisperX
# model download root so weights survive container restarts.
CACHE_DIR: str = "/cache"
cache_vol = modal.Volume.from_name(
    "whisper-cache",
    create_if_missing=True,
)
|
|
|
def _download_audio(audio_url: str, timeout: float = 300.0) -> str:
    """Stream ``audio_url`` into a fresh temporary file and return its path.

    A unique file per call is required because the class below serves up to
    15 inputs concurrently; a shared fixed filename would let requests
    overwrite each other's audio. The caller owns (and must delete) the file.

    Raises:
        requests.HTTPError: the server answered with a non-2xx status.
        requests.RequestException: connection/timeout failures.
    """
    import os
    import tempfile

    import requests

    fd, audio_path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as audio_file, requests.get(
            audio_url, stream=True, timeout=timeout
        ) as response:
            # Fail fast instead of feeding an HTTP error page to ffmpeg.
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=1 << 20):
                audio_file.write(chunk)
    except Exception:
        os.remove(audio_path)
        raise
    return audio_path


def _collect_words(aligned_result: dict) -> list:
    """Flatten aligned segments into ``{"start", "end", "word"}`` dicts.

    Words without text are dropped. Missing timings (``start``/``end``, or the
    ``start_time``/``end_time`` fallbacks) default to 0.0.
    """
    words = []
    for segment in aligned_result.get("segments", []):
        # `or []` also guards against an explicit None value.
        for word in segment.get("words") or []:
            text = word.get("word", "")
            if not text:
                continue
            words.append(
                {
                    "start": word.get("start", word.get("start_time", 0.0)),
                    "end": word.get("end", word.get("end_time", 0.0)),
                    "word": text,
                }
            )
    return words


@app.cls(
    gpu=GPU_CONFIG,
    volumes={CACHE_DIR: cache_vol},
    scaledown_window=60 * 10,
    timeout=60 * 60,
)
@modal.concurrent(max_inputs=15)
class Model:
    """WhisperX transcription service.

    The Whisper model is loaded once per container in ``setup``. Each
    ``transcribe`` call downloads one audio file, transcribes it, runs
    word-level alignment and returns a flat list of timed words. Up to 15
    inputs run concurrently per container, so per-request state (e.g. the
    downloaded file) must never be shared between calls.
    """

    @modal.enter()
    def setup(self):
        """Load the WhisperX ``large-v2`` model onto the GPU once per container."""
        import whisperx

        device = "cuda"
        compute_type = "float16"
        # Weights are downloaded into the mounted volume so later cold starts reuse them.
        self.model = whisperx.load_model(
            "large-v2", device, compute_type=compute_type, download_root=CACHE_DIR
        )

    @modal.method()
    def transcribe(self, audio_url: str):
        """Transcribe and word-align the audio at ``audio_url``.

        Args:
            audio_url: HTTP(S) URL of an audio file ffmpeg can decode.

        Returns:
            ``{"language": str, "language_probability": float,
            "words": [{"start": float, "end": float, "word": str}, ...]}``.

        Raises:
            requests.RequestException: the audio could not be downloaded.
        """
        import os

        import torch
        import whisperx

        batch_size = 16

        audio_path = _download_audio(audio_url)
        try:
            audio = whisperx.load_audio(audio_path)
        finally:
            # The decoded array is all we need; don't leak files on disk.
            os.remove(audio_path)

        result = self.model.transcribe(audio, batch_size=batch_size)
        print("Initial transcription result:", result["segments"][:1])

        # The alignment model depends on the detected language, so it is loaded per call.
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device="cuda"
        )
        try:
            # Reuse the decoded audio instead of making align() decode the file again.
            aligned_result = whisperx.align(
                result["segments"],
                model_a,
                metadata,
                audio,
                device="cuda",
                return_char_alignments=False,
            )
        finally:
            # Release the per-request alignment model even if alignment fails.
            del model_a
            gc.collect()
            torch.cuda.empty_cache()

        return {
            "language": result["language"],
            "language_probability": result.get("language_probability", 1.0),
            "words": _collect_words(aligned_result),
        }
|
|
|
|
|
@app.function()
@modal.fastapi_endpoint(docs=True, method="POST")
async def transcribe_endpoint(url: str = Form(...)):
    """Transcribe the audio at the form field ``url`` via the GPU ``Model`` class.

    Returns the dict produced by ``Model.transcribe``.

    Raises:
        HTTPException(400): ``url`` is not an http(s) URL.

    NOTE(review): only the scheme is checked; the GPU worker will fetch any
    host (including internal addresses). Consider an allowlist if exposed publicly.
    """
    if not url.startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="URL must start with http:// or https://")
    # `.remote.aio` awaits the remote call; the plain `.remote` is synchronous
    # and would block this coroutine's event loop for the whole transcription.
    return await Model().transcribe.remote.aio(audio_url=url)
|
|
|
|
|
@app.local_entrypoint()
def main():
    """Run a remote transcription of a sample audio file and print the result."""
    sample_url = "https://pub-ebe9e51393584bf5b5bea84a67b343c2.r2.dev/examples_english_english.wav"
    transcription = Model().transcribe.remote(sample_url)
    print(transcription)
|
|
|
|