Update ultravox_model.py
ultravox_model.py  CHANGED  (+112 -43)
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Set, Tuple, TypeVar, Union
 
 import peft
 import torch
@@ -56,6 +56,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.multi_modal_projector = self._create_multi_modal_projector(config)
         self.language_model = self._create_language_model(config)
 
+        if self.language_model._tied_weights_keys is not None:
+            self._tied_weights_keys = [
+                f"language_model.{k}" for k in self.language_model._tied_weights_keys
+            ]
+
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
         # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
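Note: when the wrapped language model ties its output head to its input embeddings, transformers records that fact in `_tied_weights_keys`; inside `UltravoxModel` those parameters live under the `language_model.` prefix, so the keys are re-registered with that prefix to keep weight tying intact across save and load. A tiny illustration of the prefixing (the key name below is a typical example, not taken from this diff):

lm_tied_keys = ["lm_head.weight"]  # typical value; the actual keys depend on the LM
ultravox_tied_keys = [f"language_model.{k}" for k in lm_tied_keys]
print(ultravox_tied_keys)  # ['language_model.lm_head.weight']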
@@ -64,6 +69,39 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.loss_config = LossConfig()
         self.post_init()
 
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
+        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+        model._load_child_model_weights(*args, **kwargs)
+        return model
+
+    def _load_child_model_weights(self, *args, **kwargs) -> "UltravoxModel":
+        if (
+            self.config.text_model_id is not None
+            and self.language_model.device.type == "meta"
+        ):
+            # Load the language model weights
+            self.language_model = transformers.AutoModelForCausalLM.from_pretrained(
+                self.config.text_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        if (
+            self.config.audio_model_id is not None
+            and self.audio_tower.device.type == "meta"
+        ):
+            # Load the audio tower weights
+            self.audio_tower = transformers.AutoModel.from_pretrained(
+                self.config.audio_model_id,
+                torch_dtype=self.config.torch_dtype,
+                *args,
+                **kwargs,
+            )
+
+        return self
+
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
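Note: the `from_pretrained` override above only pulls in language-model and audio-tower weights when those submodules were left on the meta device, i.e. when the checkpoint stores just the Ultravox-specific weights and points at source repos via `text_model_id` / `audio_model_id`. A minimal usage sketch, with a placeholder checkpoint id:

import torch
import transformers

# "fixie-ai/ultravox-some-checkpoint" is a placeholder id; any Ultravox checkpoint
# whose config sets text_model_id / audio_model_id exercises the same path.
model = transformers.AutoModel.from_pretrained(
    "fixie-ai/ultravox-some-checkpoint",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# from_pretrained above calls _load_child_model_weights() internally, so the
# submodules should no longer sit on the meta device.
assert model.language_model.device.type != "meta"
assert model.audio_tower.device.type != "meta"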
@@ -110,6 +148,30 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
+    def _get_prediction_mask(self, labels: Optional[torch.Tensor]) -> torch.Tensor:
+        """Get a boolean mask for positions where we want to compute KL divergence.
+
+        For each label position, we want the position before it since that's where
+        the model makes the prediction for that label.
+
+        Args:
+            labels: Tensor of shape (B, T) where B is batch size and T is sequence length,
+                with -100 for masked positions and token ids for label positions
+
+        Returns:
+            Boolean tensor of shape (B, T) that's True for positions where we want to compute KL divergence
+        """
+        if labels is None:
+            raise ValueError("labels must be provided")
+        # Shift the label mask right by 1 along the sequence dimension
+        # This gives us positions where we make predictions for the next token
+        label_mask = labels != -100
+        pred_mask = torch.zeros_like(label_mask)
+        pred_mask[:, :-1] = label_mask[
+            :, 1:
+        ]  # shift right by 1 along sequence dimension
+        return pred_mask
+
     def _compute_kl_loss(
         self,
         lm_output: transformers.modeling_outputs.CausalLMOutputWithPast,
@@ -134,11 +196,12 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
         # compute the KL divergence loss between the two models
         kl_loss = F.kl_div(
             F.log_softmax(
-                lm_output.logits[labels != -100] / self.loss_config.kl_temperature,
+                lm_output.logits[self._get_prediction_mask(labels)]
+                / self.loss_config.kl_temperature,
                 dim=-1,
             ),
             F.softmax(
-                alt_lm_output.logits[alt_labels != -100]
+                alt_lm_output.logits[self._get_prediction_mask(alt_labels)]
                 / self.loss_config.kl_temperature,
                 dim=-1,
             ),
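Note: a tiny worked example of the new prediction mask (not part of the diff). With `labels = [[-100, -100, 7, 9]]` the label mask is `[F, F, T, T]`, and the shifted prediction mask becomes `[F, T, T, F]`: the logits at positions 1 and 2 are the ones that predict tokens 7 and 9, so only those positions enter the KL term.

import torch

labels = torch.tensor([[-100, -100, 7, 9]])
label_mask = labels != -100                    # [[False, False, True,  True]]
pred_mask = torch.zeros_like(label_mask)
pred_mask[:, :-1] = label_mask[:, 1:]          # mask for label i+1 lands on position i
print(pred_mask)                               # tensor([[False,  True,  True, False]])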
@@ -289,7 +352,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
 
         # include audio information in model_input only when it is needed during prefilling
         # audio_token_start_idx should always be relative to the current cache position
-        prefill_start_idx = 0 if cache_position is None else cache_position[0]
+        prefill_start_idx: int | torch.Tensor = (
+            0 if cache_position is None else cache_position[0]
+        )
         if (
             audio_values is not None
             and audio_token_start_idx is not None
@@ -317,23 +382,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
-        if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id.lower():
-                audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-                audio_tower.init_latency_mask(
-                    config.audio_latency_block_size, dtype=config.torch_dtype
-                )
-            else:
-                assert config.audio_latency_block_size in (
-                    None,
-                    0,
-                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-                audio_tower = transformers.AutoModel.from_pretrained(
-                    config.audio_model_id, torch_dtype=config.torch_dtype
-                )
-        else:
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes the model quite creation faster since init on CPU is quite slow.
             if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
                 audio_tower.init_latency_mask(
@@ -344,12 +395,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
                     None,
                     0,
                 ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
-
-                # we only ever use from_config if the weights are retrained, hence initializing is not
-                # required. This makes the model quite creation faster since init on CPU is quite slow.
-                audio_tower = transformers.AutoModel.from_config(
-                    config.audio_config
-                )
+                audio_tower = transformers.AutoModel.from_config(config.audio_config)
 
         if isinstance(
             audio_tower,
@@ -367,21 +413,14 @@ class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     def _create_language_model(
         cls, config: UltravoxConfig
     ) -> transformers.LlamaForCausalLM:
-        if config.text_model_id is not None:
-            language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                config.text_model_id,
-                attn_implementation=config._attn_implementation,
+        with transformers.modeling_utils.no_init_weights():
+            # we only ever use from_config if the weights are retrained, hence initializing is not
+            # required. This makes the model quite creation faster since init on CPU is quite slow.
+            language_model = transformers.AutoModelForCausalLM.from_config(
+                config.text_config,
+                attn_implementation=config.text_config._attn_implementation,
                 torch_dtype=config.torch_dtype,
             )
-        else:
-            with transformers.modeling_utils.no_init_weights():
-                # we only ever use from_config if the weights are retrained, hence initializing is not
-                # required. This makes the model quite creation faster since init on CPU is quite slow.
-                language_model = transformers.AutoModelForCausalLM.from_config(
-                    config.text_config,
-                    attn_implementation=config._attn_implementation,
-                    torch_dtype=config.torch_dtype,
-                )
 
         language_model = apply_lora(language_model, config.text_model_lora_config)
         return language_model
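Note: both `_create_audio_tower` and `_create_language_model` now always build the submodule under `no_init_weights()`, skipping the slow random init on CPU; real weights arrive later, either from the Ultravox checkpoint itself or from `_load_child_model_weights`. A rough sketch of that pattern in isolation (the `gpt2` config is used purely for illustration):

import transformers

# Build an architecture-only submodule: parameters are allocated but not
# randomly initialized, which makes CPU construction much faster.
config = transformers.AutoConfig.from_pretrained("gpt2")  # illustrative config only
with transformers.modeling_utils.no_init_weights():
    skeleton = transformers.AutoModelForCausalLM.from_config(config)
# The skeleton is only usable once real weights are loaded into it,
# e.g. via load_state_dict() or a later from_pretrained() call.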
@@ -495,7 +534,10 @@ def is_cache_empty(
     return past_key_values.get_seq_length() == 0
 
 
-def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
+T = TypeVar("T", bound=torch.nn.Module)
+
+
+def apply_lora(model: T, lora_config: dict) -> T:
     """
     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
     """
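Note: binding the `TypeVar` to `torch.nn.Module` lets the return annotation track the argument's concrete class instead of widening it to `torch.nn.Module`. A minimal sketch of the typing effect (the function body is elided here):

from typing import TypeVar

import torch

T = TypeVar("T", bound=torch.nn.Module)

def apply_lora(model: T, lora_config: dict) -> T:
    ...  # body as in the file; only the annotations changed

# A type checker now keeps the concrete class of the argument, e.g. a
# LlamaForCausalLM passed in is still typed as LlamaForCausalLM on return,
# instead of being widened to torch.nn.Module.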
@@ -574,11 +616,35 @@ class UltravoxProjector(nn.Module):
         self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
+        """
+        Takes in audio features from the audio tower and projects them to the text model's embedding space.
+        It reduces the number of frames by a factor of `stack_factor` and increases the number of channels by the same factor.
+        If the number of audio frames are not a multiple of the stack factor, the last few frames will be padded with zeros.
+
+        Input shape:
+            audio_features: B, T*S, C
+        Output shape:
+            hidden_states: B, T, D
+        Where:
+            B: batch size
+            F: number of frames in the audio tower
+            T: number of output embeddings
+                T = ceil(F / S)
+            S: stack factor
+            C: number of channels out of the encoder (aka audio tower)
+            H: hidden size of the projector (config.hidden_size)
+            D: dimension of the text model (config.text_config.hidden_size)
+
+        """
+        # B, F, C -> B, T, C*S
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
+        # B, T, C*S -> B, T, H
         hidden_states = self.linear_1(audio_features)
+        # B, T, H -> B, T, H/2 (assuming swiglu)
         hidden_states = self.act(hidden_states)
         hidden_states = self.ln_mid(hidden_states)
+        # B, T, H/2 -> B, T, D
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
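Note: a concrete instance of the shapes documented in the docstring above; the numbers are illustrative, and the padding/stacking below is a rough re-implementation of the `_pad_and_stack` step for demonstration, not the projector code itself.

import math
import torch

# Illustrative numbers only: B=2 clips, F=301 encoder frames, C=1280 channels, S=8.
B, F, C, S = 2, 301, 1280, 8
audio_features = torch.randn(B, F, C)

T = math.ceil(F / S)                                           # 38 output embeddings
padded = torch.nn.functional.pad(audio_features, (0, 0, 0, T * S - F))  # pad 3 frames
stacked = padded.reshape(B, T, C * S)                          # fold S frames into channels
print(stacked.shape)                                           # torch.Size([2, 38, 10240])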
@@ -601,6 +667,7 @@ class ModifiedWhisperEncoder(
 
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
+    _keys_to_ignore_on_load_unexpected = ["model.decoder.*"]
 
     def __init__(self, config: transformers.WhisperConfig):
         super().__init__(config)
@@ -614,7 +681,9 @@ class ModifiedWhisperEncoder(
            * self.conv2.stride[0]
        )
 
-    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+    def init_latency_mask(
+        self, audio_latency_block_size: int | None, dtype: torch.dtype
+    ):
         if audio_latency_block_size is None:
             self.audio_streaming_mask = None
             return
@@ -781,4 +850,4 @@ UltravoxModel.register_for_auto_class()
 transformers.AutoConfig.register("ultravox", UltravoxConfig)
 transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
 
-transformers.activations.ACT2FN["swiglu"] = SwiGLU
+transformers.activations.ACT2FN["swiglu"] = SwiGLU