File size: 2,763 Bytes
8d6b944
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import logging
import os

import numpy as np
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Checkpoint identifier passed to every `from_pretrained` call below.
MODEL_ID = "openai/whisper-large-v3-turbo"
# NOTE(review): LANGUAGE is not referenced anywhere in the visible code;
# presumably it is consumed by callers that import this module -- confirm.
LANGUAGE = "english"

# Precision used for the model weights / pipeline and for the warm-up audio.
torch_dtype = torch.float16
np_dtype = np.float16

# Target device, used whenever the model is moved explicitly with `.to()`.
device = "cuda"

# Feature switches: each one enables an optimisation that the code below
# attempts and may abandon if it fails.
use_device_map = True
try_use_flash_attention = True
try_compile_model = True

# Initialize the model (use flash attention on cuda if possible)
def _load_asr_model():
    """Load the checkpoint named by MODEL_ID, degrading step by step on failure.

    Configurations are tried from most to least optimised:

    1. preferred attention kernel with ``device_map="auto"`` (only when
       ``use_device_map`` is set),
    2. preferred attention kernel, loaded normally and moved to ``device``,
    3. SDPA attention, loaded normally and moved to ``device`` (only when the
       preferred kernel was FlashAttention-2).

    Returns:
        The model produced by the first configuration that succeeds.

    Raises:
        Exception: The error from the last configuration, or any error of a
            kind that another configuration is not expected to fix. It is
            logged before being re-raised.
    """
    # Ask for FlashAttention-2 only when transformers reports it as available,
    # so a missing flash-attn install does not cost a failed load attempt.
    flash_ok = try_use_flash_attention and is_flash_attn_2_available()
    preferred_attn = "flash_attention_2" if flash_ok else "sdpa"

    # Each entry: (attn_implementation, device_map, warning logged when this
    # entry is reached as a fallback).
    attempts = []
    if use_device_map:
        attempts.append((preferred_attn, "auto", None))
    attempts.append((preferred_attn, None, "Falling back to device_map=None"))
    if preferred_attn != "sdpa":
        attempts.append(("sdpa", None, "Disabling flash attention"))

    # Catching RuntimeError alone is not enough: an unusable attention kernel
    # is typically reported as ImportError or ValueError.
    recoverable = (ImportError, RuntimeError, ValueError)
    final = len(attempts) - 1

    for position, (attn, placement, notice) in enumerate(attempts):
        if position > 0:
            logger.warning(notice)
        try:
            candidate = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_ID,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation=attn,
                device_map=placement,
            )
            # Without a device map the weights must be moved explicitly.
            if placement is None:
                candidate.to(device)
            return candidate
        except Exception as e:
            if position < final and isinstance(e, recoverable):
                logger.warning(f"Model load attempt failed: {e}")
                continue
            logger.error(f"Error loading ASR model: {e}")
            logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
            raise


model = _load_asr_model()

# The processor supplies the tokenizer and the feature extractor that the
# pipeline below is built from.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Speech-recognition pipeline wrapping the already-loaded model; the task
# name is given as the first positional argument.
transcribe_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
)

# Try to compile the model. Keep a handle on the uncompiled module so it can
# be restored if the compiled one turns out to be unusable.
_eager_model = transcribe_pipeline.model
if try_compile_model:
    try:
        transcribe_pipeline.model = torch.compile(_eager_model, mode="max-autotune")
    except Exception as e:
        logger.warning(f"Error compiling model: {e}")
        logger.warning("Proceeding without compiling the model")
else:
    logger.warning("Proceeding without compiling the model (requirements not met)")

# Warm up the model with one pass over 16000 samples of random noise.
# torch.compile defers the actual compilation to the first call, so this is
# also where most compilation failures surface.
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.random.rand(16000).astype(np_dtype)
try:
    transcribe_pipeline(warmup_audio)
except Exception as e:
    # If the model was never compiled there is nothing to fall back to.
    if transcribe_pipeline.model is _eager_model:
        raise
    logger.warning(f"Warmup with compiled model failed: {e}")
    logger.warning("Proceeding without compiling the model")
    transcribe_pipeline.model = _eager_model
    # A failure here is unrelated to compilation and is allowed to propagate.
    transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")