Sofia Casadei committed on
Commit
5c44b80
·
1 Parent(s): edfee48
Files changed (1) hide show
  1. main.py +50 -18
main.py CHANGED
@@ -42,30 +42,55 @@ MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
42
  LANGUAGE = os.getenv("LANGUAGE", "english").lower()
43
 
44
  device = get_device(force_cpu=False)
 
 
 
45
 
46
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
47
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
48
 
49
- attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
50
- logger.info(f"Using attention: {attention}")
51
-
52
  logger.info(f"Loading Whisper model: {MODEL_ID}")
53
  logger.info(f"Using language: {LANGUAGE}")
54
 
 
55
  try:
56
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
57
- MODEL_ID,
58
- torch_dtype=torch_dtype,
59
- low_cpu_mem_usage=True,
60
  use_safetensors=True,
61
- attn_implementation=attention,
62
- device_map="auto" if device == "cuda" else None
63
  )
64
- #model.to(device)
65
- except Exception as e:
66
- logger.error(f"Error loading ASR model: {e}")
67
- logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
68
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  processor = AutoProcessor.from_pretrained(MODEL_ID)
71
 
@@ -74,15 +99,22 @@ transcribe_pipeline = pipeline(
74
  model=model,
75
  tokenizer=processor.tokenizer,
76
  feature_extractor=processor.feature_extractor,
77
- torch_dtype=torch_dtype,
78
- #device=device,
79
  )
80
- #if device == "cuda":
81
- # transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
 
 
 
 
 
 
 
 
82
 
83
  # Warm up the model with empty audio
84
  logger.info("Warming up Whisper model with dummy input")
85
- warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
86
  transcribe_pipeline(warmup_audio)
87
  logger.info("Model warmup complete")
88
 
 
42
  LANGUAGE = os.getenv("LANGUAGE", "english").lower()
43
 
44
  device = get_device(force_cpu=False)
45
+ use_device_map = True if device == "cuda" else False
46
+ try_compile_model = True if device == "cuda" or (device == "mps" and torch.__version__ >= "2.7.0") else False
47
+ try_use_flash_attention = True if device == "cuda" and is_flash_attn_2_available() else False
48
 
49
  torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
50
  logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
51
 
 
 
 
52
  logger.info(f"Loading Whisper model: {MODEL_ID}")
53
  logger.info(f"Using language: {LANGUAGE}")
54
 
55
+ # Initialize the model (use flash attention on cuda if possible)
56
  try:
57
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
58
+ MODEL_ID,
59
+ torch_dtype=torch_dtype,
60
+ low_cpu_mem_usage=True,
61
  use_safetensors=True,
62
+ attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
63
+ device_map="auto" if use_device_map else None,
64
  )
65
+ if not use_device_map:
66
+ model.to(device)
67
+ except RuntimeError as e:
68
+ try:
69
+ logger.warning("Falling back to device_map=None")
70
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
71
+ MODEL_ID,
72
+ torch_dtype=torch_dtype,
73
+ low_cpu_mem_usage=True,
74
+ use_safetensors=True,
75
+ attn_implementation="flash_attention_2" if try_use_flash_attention else "sdpa",
76
+ device_map=None,
77
+ )
78
+ model.to(device)
79
+ except RuntimeError as e:
80
+ try:
81
+ logger.warning("Disabling flash attention")
82
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
83
+ MODEL_ID,
84
+ torch_dtype=torch_dtype,
85
+ low_cpu_mem_usage=True,
86
+ use_safetensors=True,
87
+ attn_implementation="sdpa",
88
+ )
89
+ model.to(device)
90
+ except Exception as e:
91
+ logger.error(f"Error loading ASR model: {e}")
92
+ logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
93
+ raise
94
 
95
  processor = AutoProcessor.from_pretrained(MODEL_ID)
96
 
 
99
  model=model,
100
  tokenizer=processor.tokenizer,
101
  feature_extractor=processor.feature_extractor,
102
+ torch_dtype=torch_dtype
 
103
  )
104
+
105
+ # Try to compile the model
106
+ try:
107
+ if try_compile_model:
108
+ transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")
109
+ else:
110
+ logger.warning("Proceeding without compiling the model (requirements not met)")
111
+ except Exception as e:
112
+ logger.warning(f"Error compiling model: {e}")
113
+ logger.warning("Proceeding without compiling the model")
114
 
115
  # Warm up the model with empty audio
116
  logger.info("Warming up Whisper model with dummy input")
117
+ warmup_audio = np.random.rand(16000).astype(np_dtype)
118
  transcribe_pipeline(warmup_audio)
119
  logger.info("Model warmup complete")
120