Sofia Casadei committed
Commit 8d6b944 · 1 Parent(s): aacc5eb
Files changed (2)
  1. Dockerfile +1 -2
  2. mwe_whisper_flashattn.py +88 -0
Dockerfile CHANGED
@@ -11,8 +11,7 @@ COPY --from=uv /uv /uv
 
 # Install Python, pip, venv, and system dependencies
 RUN apt-get update && \
-    apt-get upgrade -y && \
-    apt-get install -y --no-install-recommends \
+    apt-get install -y --fix-missing --no-install-recommends \
     python3.11 python3.11-venv python3-pip ffmpeg \
     build-essential \
     git \
mwe_whisper_flashattn.py ADDED
@@ -0,0 +1,89 @@
+import os
+import logging
+
+import torch
+import numpy as np
+from transformers import (
+    AutoModelForSpeechSeq2Seq,
+    AutoProcessor,
+    pipeline,
+)
+from transformers.utils import is_flash_attn_2_available
+
+logger = logging.getLogger(__name__)
+
+MODEL_ID = "openai/whisper-large-v3-turbo"
+LANGUAGE = "english"
+
+device = "cuda"
+use_device_map = True
+try_compile_model = True
+try_use_flash_attention = is_flash_attn_2_available()  # request flash attention only if the package is installed
+torch_dtype = torch.float16
+np_dtype = np.float16
+
+# Initialize the model (use flash attention on cuda if possible)
+try:
+    model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        MODEL_ID,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        use_safetensors=True,
+        attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
+        device_map="auto" if use_device_map else None,
+    )
+    if not use_device_map:
+        model.to(device)
+except RuntimeError as e:
+    try:
+        logger.warning(f"Falling back to device_map=None: {e}")
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch_dtype,
+            low_cpu_mem_usage=True,
+            use_safetensors=True,
+            attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
+            device_map=None,
+        )
+        model.to(device)
+    except RuntimeError as e:
+        try:
+            logger.warning(f"Disabling flash attention: {e}")
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                MODEL_ID,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_safetensors=True,
+                attn_implementation="sdpa",
+            )
+            model.to(device)
+        except Exception as e:
+            logger.error(f"Error loading ASR model: {e}")
+            logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
+            raise
+
+processor = AutoProcessor.from_pretrained(MODEL_ID)
+
+transcribe_pipeline = pipeline(
+    task="automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    torch_dtype=torch_dtype,
+)
+
+# Try to compile the model
+try:
+    if try_compile_model:
+        transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
+    else:
+        logger.warning("Proceeding without compiling the model (requirements not met)")
+except Exception as e:
+    logger.warning(f"Error compiling model: {e}")
+    logger.warning("Proceeding without compiling the model")
+
+# Warm up the model with one second of dummy audio
+logger.info("Warming up Whisper model with dummy input")
+warmup_audio = np.random.rand(16000).astype(np_dtype)
+transcribe_pipeline(warmup_audio)
+logger.info("Model warmup complete")
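
For quick verification, a minimal driver sketch (not part of the commit) is shown below. It assumes the script above is importable as mwe_whisper_flashattn and that sample.wav is a hypothetical recording on disk; chunk_length_s and generate_kwargs are standard transformers pipeline arguments. basicConfig runs before the import so the module-level warmup/fallback log messages are visible.

import logging

logging.basicConfig(level=logging.INFO)  # surface the warmup/fallback log messages

from mwe_whisper_flashattn import transcribe_pipeline

result = transcribe_pipeline(
    "sample.wav",  # hypothetical audio file; ffmpeg (installed in the Dockerfile) handles decoding
    chunk_length_s=30,  # split long audio into Whisper's 30-second windows
    generate_kwargs={"language": "english"},  # same language as the LANGUAGE constant
)
print(result["text"])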