# video_editing_genie_mcp_server / deploy_whisper_on_modal.py
"""Modal app serving WhisperX transcription with word-level timestamps."""

import gc
import hashlib

import modal
from fastapi import Form, HTTPException

cuda_version = "12.4.0" # should be no greater than host CUDA version
flavor = "devel" # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install(
        "git",
        "ffmpeg",
        "libcudnn8",
        "libcudnn8-dev",
    )
    .pip_install(
        "fastapi[standard]",
        "httpx",
        "torch==2.0.0",
        "torchaudio==2.0.0",
        "numpy<2.0",
        extra_index_url="https://download.pytorch.org/whl/cu118",
    )
    .pip_install(
        "git+https://github.com/m-bain/[email protected]",
        "ffmpeg-python",
        "ctranslate2==4.4.0",
    )
)

app = modal.App("whisperx-api", image=image)
GPU_CONFIG = "L4"
CACHE_DIR = "/cache"
cache_vol = modal.Volume.from_name("whisper-cache", create_if_missing=True)

@app.cls(
    gpu=GPU_CONFIG,
    volumes={CACHE_DIR: cache_vol},
    scaledown_window=60 * 10,  # keep an idle container alive for 10 minutes
    timeout=60 * 60,  # allow up to an hour per call
)
@modal.concurrent(max_inputs=15)
class Model:
    @modal.enter()
    def setup(self):
        """Load WhisperX large-v2 once per container, caching weights in the volume."""
        import whisperx

        device = "cuda"
        compute_type = "float16"
        self.model = whisperx.load_model(
            "large-v2", device, compute_type=compute_type, download_root=CACHE_DIR
        )
    @modal.method()
    def transcribe(self, audio_url: str):
        """Download audio from a URL and return word-level timestamps."""
        import requests
        import torch
        import whisperx

        batch_size = 16

        # Download to a per-URL path so concurrent inputs in the same container
        # do not overwrite each other's audio file.
        response = requests.get(audio_url)
        response.raise_for_status()
        audio_path = f"/tmp/{hashlib.sha256(audio_url.encode()).hexdigest()}.wav"
        with open(audio_path, "wb") as audio_file:
            audio_file.write(response.content)

        audio = whisperx.load_audio(audio_path)
        result = self.model.transcribe(audio, batch_size=batch_size)
        print("Initial transcription result:", result["segments"][:1])

        # Align segments with a language-specific model to get word timestamps.
        model_a, metadata = whisperx.load_align_model(
            language_code=result["language"], device="cuda"
        )
        aligned_result = whisperx.align(
            result["segments"],
            model_a,
            metadata,
            audio_path,
            device="cuda",
            return_char_alignments=False,
        )

        # Free the alignment model before returning.
        del model_a
        gc.collect()
        torch.cuda.empty_cache()

        results = {
            "language": result["language"],
            "language_probability": result.get("language_probability", 1.0),
            "words": [],
        }
        for segment in aligned_result.get("segments", []):
            for word in segment.get("words", []):
                word_data = {
                    "start": word.get("start", word.get("start_time", 0.0)),
                    "end": word.get("end", word.get("end_time", 0.0)),
                    "word": word.get("word", ""),
                }
                if word_data["word"]:
                    results["words"].append(word_data)
        return results
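
# Illustrative shape of the payload returned by Model.transcribe (values are
# made up for the example; the exact fields follow the dict built above):
#
#   {
#       "language": "en",
#       "language_probability": 1.0,
#       "words": [
#           {"start": 0.52, "end": 0.89, "word": "Hello"},
#           {"start": 0.95, "end": 1.30, "word": "world"},
#       ],
#   }
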
@app.function()
@modal.fastapi_endpoint(docs=True, method="POST")
async def transcribe_endpoint(url: str = Form(...)):
    """Accept a form-encoded audio URL and return the word-level transcript."""
    if not url.startswith(("http://", "https://")):
        raise HTTPException(status_code=400, detail="URL must start with http:// or https://")
    return Model().transcribe.remote(audio_url=url)
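
# Example request against the deployed endpoint. The URL below is a placeholder;
# Modal prints the actual web endpoint URL when the app is deployed:
#
#   curl -X POST "https://<your-workspace>--whisperx-api-transcribe-endpoint.modal.run" \
#        -F "url=https://example.com/audio.wav"
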
# ## Run the model
@app.local_entrypoint()
def main():
    """Smoke-test the model against a sample audio file."""
    url = "https://pub-ebe9e51393584bf5b5bea84a67b343c2.r2.dev/examples_english_english.wav"
    print(Model().transcribe.remote(url))
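
# A minimal way to exercise this file with the Modal CLI (assuming the filename
# in the header comment):
#
#   modal run deploy_whisper_on_modal.py      # runs main() above on Modal
#   modal deploy deploy_whisper_on_modal.py   # deploys the app and the web endpoint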