farzadab committed
Commit 15893c0 · verified · 1 Parent(s): a4cd5e8

Update ultravox_model.py

Files changed (1): ultravox_model.py (+112, -43)
ultravox_model.py CHANGED
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
 
 import peft
 import torch
@@ -56,6 +56,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)
 
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [
+                f"language_model.{k}" for k in self.language_model._tied_weights_keys
+            ]
+
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
         # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
@@ -64,6 +69,39 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.loss_config = LossConfig()
         self.post_init()
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        model._load_child_model_weights(*args, **kwargs)
+        return model
+
+    def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
+        if (
+            self.config.text_model_id is not None
+            and self.language_model.device.type == "meta"
+        ):
+            # Load the language model weights
+            self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
+                self.config.text_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        if (
+            self.config.audio_model_id is not None
+            and self.audio_tower.device.type == "meta"
+        ):
+            # Load the audio tower weights
+            self.audio_tower = transformers.AutoModel.from_pretrained(
+                self.config.audio_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        return self
+
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
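Note: the new `from_pretrained` override assumes the child models were built without real weights (see the `no_init_weights` hunks further down) and reloads any submodule whose parameters are still on the `meta` device. A minimal sketch of that device check on a toy module, not part of the commit:

import torch

# Parameters created under the meta device have no real storage yet.
with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)
print(layer.weight.device.type)  # "meta"

# The override applies the same kind of check to self.language_model / self.audio_tower
# and, when it holds, swaps the placeholder for a fully loaded from_pretrained copy.
if layer.weight.device.type == "meta":
    layer = torch.nn.Linear(4, 4)  # stand-in for the real reload
print(layer.weight.device.type)  # "cpu"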
@@ -110,6 +148,30 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    def _get_prediction_mask(self, labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Get a boolean mask for positions where we want to compute KL divergence.
+
+        For each label position, we want the position before it since that's where
+        the model makes the prediction for that label.
+
+        Args:
+            labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
+                with -100 for masked positions and token ids for label positions
+
+        Returns:
+            Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+        """
+        if labels is None:
+            raise ValueError("labels must be provided")
+        # Shift the label mask right by 1 along the sequence dimension
+        # This gives us positions where we make predictions for the next token
+        label_mask = labels != -100
+        pred_mask = torch.zeros_like(label_mask)
+        pred_mask[:, :-1] = label_mask[
+            :, 1:
+        ]  # shift right by 1 along sequence dimension
+        return pred_mask
+
     def _compute_kl_loss(
         self,
         lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
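Why the shift in `_get_prediction_mask`: the logits at position i predict the token at position i + 1, so the KL term has to be read one step before each labelled token. A toy trace of the mask (illustrative only, not part of the commit):

import torch

labels = torch.tensor([[-100, -100, 42, 43, -100]])
label_mask = labels != -100                 # [[False, False, True, True, False]]
pred_mask = torch.zeros_like(label_mask)
pred_mask[:, :-1] = label_mask[:, 1:]       # [[False, True, True, False, False]]
# Position i is kept when position i + 1 carries a label: the logits at i are the
# ones that predict that label token.
print(pred_mask)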
@@ -134,11 +196,12 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         # compute the KL divergence loss between the two models
         kl_loss = F.kl_div(
             F.log_softmax(
-                lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
+                lm_output.logits[self._get_prediction_mask(labels)]
+                / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             F.softmax(
-                alt_lm_output.logits[alt_labels != -100]
+                alt_lm_output.logits[self._get_prediction_mask(alt_labels)]
                 / self.loss_config.kl_temperature,
                 dim=-1,
             ),
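For reference, the masked KL term above can be reproduced in isolation roughly as follows; the temperature value and the `batchmean` reduction are assumptions for illustration, not taken from the diff:

import torch
import torch.nn.functional as F

student_logits = torch.randn(2, 5, 32)      # (B, T, V) logits of the audio+text model
teacher_logits = torch.randn(2, 5, 32)      # (B, T, V) logits of the text-only model
pred_mask = torch.zeros(2, 5, dtype=torch.bool)
pred_mask[:, 1:3] = True                    # pretend these are the prediction positions
tau = 2.0                                   # stand-in for self.loss_config.kl_temperature

kl = F.kl_div(
    F.log_softmax(student_logits[pred_mask] / tau, dim=-1),
    F.softmax(teacher_logits[pred_mask] / tau, dim=-1),
    reduction="batchmean",
)
print(kl)  # scalar tensor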
@@ -289,7 +352,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
 
         # include audio information in model_input only when it is needed during prefilling
         # audio_token_start_idx should always be relative to the current cache position
-        prefill_start_idx = 0 if cache_position is None else cache_position[0]
+        prefill_start_idx: int | torch.Tensor = (
+            0 if cache_position is None else cache_position[0]
+        )
         if (
             audio_values is not None
             and audio_token_start_idx is not None
@@ -317,23 +382,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
-        if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id.lower():
-                audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-                audio_tower.init_latency_mask(
-                    config.audio_latency_block_size, dtype=config.torch_dtype
-                )
-            else:
-                assert config.audio_latency_block_size in (
-                    None,
-                    0,
-                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                audio_tower = transformers.AutoModel.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-        else:
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes model creation quite a bit faster since init on CPU is quite slow.
             if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
                 audio_tower.init_latency_mask(
@@ -344,12 +395,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                     None,
                     0,
                 ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                with transformers.modeling_utils.no_init_weights():
-                    # we only ever use from_config if the weights are retrained, hence initializing is not
-                    # required. This makes the model quite creation faster since init on CPU is quite slow.
-                    audio_tower = transformers.AutoModel.from_config(
-                        config.audio_config
-                    )
+                audio_tower = transformers.AutoModel.from_config(config.audio_config)
 
         if isinstance(
             audio_tower,
@@ -367,21 +413,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_language_model(
         cls, config: UltravoxConfig
     ) -> transformers.LlamaForCausalLM:
-        if config.text_model_id is not None:
-            language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                config.text_model_id,
-                attn_implementation=config._attn_implementation,
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes model creation quite a bit faster since init on CPU is quite slow.
+            language_model = transformers.AutoModelForCausalLM.from_config(
+                config.text_config,
+                attn_implementation=config.text_config._attn_implementation,
                 torch_dtype=config.torch_dtype,
             )
-        else:
-            with transformers.modeling_utils.no_init_weights():
-                # we only ever use from_config if the weights are retrained, hence initializing is not
-                # required. This makes the model quite creation faster since init on CPU is quite slow.
-                language_model = transformers.AutoModelForCausalLM.from_config(
-                    config.text_config,
-                    attn_implementation=config._attn_implementation,
-                    torch_dtype=config.torch_dtype,
-                )
 
         language_model = apply_lora(language_model, config.text_model_lora_config)
         return language_model
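A small sketch of the skip-initialization pattern now shared by `_create_audio_tower` and `_create_language_model`; the tiny config values below are made up so the example stays cheap to build:

import transformers

cfg = transformers.AutoConfig.for_model(
    "llama",
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=1,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=256,
)
with transformers.modeling_utils.no_init_weights():
    # Weight init is skipped; the values are garbage until a checkpoint
    # (or _load_child_model_weights) overwrites them.
    lm = transformers.AutoModelForCausalLM.from_config(cfg)
print(sum(p.numel() for p in lm.parameters()))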
@@ -495,7 +534,10 @@ def is_cache_empty(
     return past_key_values.get_seq_length() == 0
 
 
-def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
+T = TypeVar("T", bound=torch.nn.Module)
+
+
+def apply_lora(model: T, lora_config: dict) -> T:
     """
     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
     """
@@ -574,11 +616,35 @@ class UltravoxProjector(nn.Module):
         self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+        """
+        Takes in audio features from the audio tower and projects them to the text model's embedding space.
+        It reduces the number of frames by a factor of `stack_factor` and increases the number of channels by the same factor.
+        If the number of audio frames is not a multiple of the stack factor, the last few frames are padded with zeros.
+
+        Input shape:
+            audio_features: B, F, C
+        Output shape:
+            hidden_states: B, T, D
+        Where:
+            B: batch size
+            F: number of frames in the audio tower
+            T: number of output embeddings
+                T = ceil(F / S)
+            S: stack factor
+            C: number of channels out of the encoder (aka audio tower)
+            H: hidden size of the projector (config.hidden_size)
+            D: dimension of the text model (config.text_config.hidden_size)
+
+        """
+        # B, F, C -> B, T, C*S
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
+        # B, T, C*S -> B, T, H
         hidden_states = self.linear_1(audio_features)
+        # B, T, H -> B, T, H/2 (assuming swiglu)
         hidden_states = self.act(hidden_states)
         hidden_states = self.ln_mid(hidden_states)
+        # B, T, H/2 -> B, T, D
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
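The shape bookkeeping in the projector docstring can be traced with a toy tensor; only the pad-and-stack step is reproduced below and the sizes are made up (the real work happens in `_pad_and_stack` and the linear layers):

import math

import torch
import torch.nn.functional as F

B, F_frames, C, S = 2, 10, 16, 8        # batch, encoder frames, encoder channels, stack factor
T = math.ceil(F_frames / S)             # number of output embeddings -> 2

x = torch.randn(B, F_frames, C)
pad = T * S - F_frames                  # 6 zero frames appended so the frames divide evenly
x = F.pad(x, (0, 0, 0, pad))
x = x.reshape(B, T, C * S)              # B, F, C -> B, T, C*S
print(x.shape)                          # torch.Size([2, 2, 128])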
@@ -601,6 +667,7 @@ class ModifiedWhisperEncoder(
 
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
+    _keys_to_ignore_on_load_unexpected = ["model.decoder.*"]
 
     def __init__(self, config: transformers.WhisperConfig):
         super().__init__(config)
@@ -614,7 +681,9 @@ class ModifiedWhisperEncoder(
             * self.conv2.stride[0]
         )
 
-    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+    def init_latency_mask(
+        self, audio_latency_block_size: int | None, dtype: torch.dtype
+    ):
         if audio_latency_block_size is None:
             self.audio_streaming_mask = None
             return
@@ -781,4 +850,4 @@ UltravoxModel.register_for_auto_class()
 transformers.AutoConfig.register("ultravox", UltravoxConfig)
 transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
 
-transformers.activations.ACT2FN["swiglu"] = SwiGLU
+transformers.activations.ACT2FN["swiglu"] = SwiGLU