Update modeling_gemma3_omni.py
modeling_gemma3_omni.py  (+48 -55)
@@ -1,20 +1,16 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations

-# ────────────────────────────────────────────────────────────────────────────────
-# 0. Monkey-patch Gemma3TextScaledWordEmbedding.forward --> clone() to break view relation
-# ────────────────────────────────────────────────────────────────────────────────
 import torch
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding as _OrigEmb

+
 def _patched_forward(self: _OrigEmb, input_ids: torch.LongTensor):
     return super(_OrigEmb, self).forward(input_ids).clone()

+
 _OrigEmb.forward = _patched_forward

-# ────────────────────────────────────────────────────────────────────────────────
-# 1. Standard imports (the rest of your original file starts here)
-# ────────────────────────────────────────────────────────────────────────────────
 from typing import List, Optional, Tuple, Union, Callable

 from transformers import (
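The hunk keeps the module-level monkey-patch while dropping its banner comments: `Gemma3TextScaledWordEmbedding.forward` is replaced so every lookup returns a `clone()`, which is what the removed comment describes as breaking the view relation. A minimal sketch of the same patch-and-clone pattern on a hypothetical toy module (not the real Gemma3 class):

import torch
from torch import nn


class TinyLookup(nn.Module):
    """Toy stand-in (hypothetical, not the real Gemma3 embedding) whose forward returns a view."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 8))

    def forward(self, row: int) -> torch.Tensor:
        return self.weight[row]  # slicing returns a view that shares storage with self.weight


_orig_forward = TinyLookup.forward


def _cloned_forward(self: TinyLookup, row: int) -> torch.Tensor:
    # clone() hands the caller a tensor with its own storage, breaking the view relation.
    return _orig_forward(self, row).clone()


TinyLookup.forward = _cloned_forward  # class-level monkey-patch, as in the hunk above

m = TinyLookup()
out = m(1)
# Storage is no longer shared with the parameter after the patch.
print(out.untyped_storage().data_ptr() == m.weight.untyped_storage().data_ptr())  # False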
@@ -83,20 +79,24 @@ class Gemma3AudioProjectorConfig(PretrainedConfig):

 from torch import nn

+
 class LayerWiseWeightedSum(nn.Module):
     def __init__(self, num_layers: int, learnable: bool = True):
         super().__init__()
         self.num_layers = num_layers
         if learnable:
-            self.
+            self.scalar_weights = nn.Parameter(torch.zeros(num_layers))
         else:
-            self.register_buffer("
+            self.register_buffer("scalar_weights", torch.zeros(num_layers))

     def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
-
-
-
-
+        if len(hidden_states) != self.num_layers:
+            raise ValueError(f"Expected {self.num_layers} hidden states, but got {len(hidden_states)}")
+
+        norm_weights = torch.softmax(self.scalar_weights, dim=0).view(-1, 1, 1, 1)
+        stacked_states = torch.stack(hidden_states, dim=0)
+        weighted_sum = (norm_weights * stacked_states).sum(dim=0)
+        return weighted_sum


 class Gemma3AudioProjector(PreTrainedModel):
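The new `LayerWiseWeightedSum` learns one scalar per encoder block and fuses the block outputs with a softmax-normalised weighted sum; because the weights start at zero, the softmax is initially uniform and every layer contributes 1/num_layers. A minimal usage sketch of the class as added above (shapes are invented for illustration):

import torch

# Assumes LayerWiseWeightedSum from the hunk above is defined in scope.
num_layers, batch, frames, dim = 4, 2, 10, 256
weighter = LayerWiseWeightedSum(num_layers=num_layers)  # learnable=True by default

hidden_states = [torch.randn(batch, frames, dim) for _ in range(num_layers)]
fused = weighter(hidden_states)  # softmax(scalar_weights)-weighted sum over the layer axis

print(fused.shape)                                       # torch.Size([2, 10, 256])
print(torch.softmax(weighter.scalar_weights, dim=0))     # uniform 0.25 per layer at init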
@@ -143,19 +143,15 @@ class Gemma3AudioProjector(PreTrainedModel):
         self.layer_weighter = LayerWiseWeightedSum(
             num_layers=encoder_config["num_blocks"]
         )
+        self.norm = Gemma3RMSNorm(encoder_config['attention_dim'], eps=1e-6)
         self.proj = nn.Linear(encoder_config['attention_dim'], config.hidden_size, bias=False)

-    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor):
-        mel = mel.squeeze(1)  # (B, T, 80)
-        mel_mask = mel_mask.squeeze(1)  # (B, L)
-
-        if mel_mask.size(1) != mel.size(1):
-            mel_mask = mel_mask[..., : mel.size(1)]
-
+    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         _, out_mask, hidden_list = self.encoder(mel, mel_mask)
-
-
-
+        features = self.layer_weighter(hidden_list)
+        normalized_features = self.norm(features)
+        projected_features = self.proj(normalized_features)
+        return projected_features, out_mask


 class Gemma3VisionProjector(nn.Module):
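The reworked projector head fuses the per-block hidden states, applies RMS normalisation, and projects to the text hidden size, returning the projected features together with the encoder's output mask. A rough shape walk-through with stand-in modules (nn.LayerNorm substitutes for Gemma3RMSNorm, and all dimensions are invented):

import torch
from torch import nn

# Stand-in dimensions; the real values come from encoder_config and config.hidden_size.
attention_dim, text_hidden_size, num_blocks = 512, 2048, 8

layer_weighter = LayerWiseWeightedSum(num_layers=num_blocks)   # class from the earlier hunk
norm = nn.LayerNorm(attention_dim)                             # stand-in for Gemma3RMSNorm(attention_dim, eps=1e-6)
proj = nn.Linear(attention_dim, text_hidden_size, bias=False)

# Pretend encoder output: one hidden state per block, each (batch, frames, attention_dim).
hidden_list = [torch.randn(2, 50, attention_dim) for _ in range(num_blocks)]

features = layer_weighter(hidden_list)   # (2, 50, 512)
audio_embeds = proj(norm(features))      # (2, 50, 2048), matching the text hidden size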
@@ -188,6 +184,7 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Opti

     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         return token_type_ids[batch_idx, kv_idx] != 0
+
     return inner_mask


@@ -199,7 +196,8 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.multi_modal_projector = Gemma3VisionProjector(config)
         self.audio_projector = Gemma3AudioProjector(
-            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size
+            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size, n_mels=config.audio_config.n_mels,
+                                       num_hidden_layers=config.audio_config.num_hidden_layers)
         )
         self.vocab_size = config.text_config.vocab_size

@@ -235,7 +233,6 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         **lm_kwargs,
     ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            print("input_ids:", input_ids, "inputs_embeds:", inputs_embeds)
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -261,15 +258,14 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         )

         if pixel_values is not None and past_key_values is None:
-
+            vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            image_features = self.multi_modal_projector(vision_outputs.hidden_states[-1])

             if input_ids is None:
-
-
-
-
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                raise ValueError("`input_ids` are required when `pixel_values` are provided.")
+
+            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                 image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
@@ -286,23 +282,14 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                 input_audio_embeds, audio_attention_mask
             )
             if input_ids is None:
-
-
-
-
-
-
-
-
-                )
-                )
-            else:
-                special_audio_mask = (
-                    input_ids == self.config.audio_token_index
-                ).unsqueeze(-1)
-                special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
-                    inputs_embeds.device
-                )
+                raise ValueError("`input_ids` are required when `input_audio_embeds` are provided.")
+
+            special_audio_mask = (
+                input_ids == self.config.audio_token_index
+            ).unsqueeze(-1)
+            special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
             if (
                 not is_torchdynamo_compiling()
                 and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
@@ -314,9 +301,9 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                     f"({audio_features.shape[0] * audio_features.shape[1]})."
                 )
             audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features.flatten(0, 1))

-        if not isinstance(
+        if not isinstance(attention_mask, dict):
             mask_kwargs = {
                 "config": self.config.get_text_config(),
                 "input_embeds": inputs_embeds,
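The `.flatten(0, 1)` added here collapses the batch and token axes of `audio_features` so the scattered source has exactly as many elements as there are `True` positions in the expanded mask. A toy illustration of that pattern (all sizes invented):

import torch

hidden, n_audio_tokens = 4, 3
inputs_embeds = torch.zeros(1, 6, hidden)               # (batch, seq, hidden)
special_audio_mask = (
    torch.tensor([[0, 1, 1, 1, 0, 0]], dtype=torch.bool)
    .unsqueeze(-1)
    .expand_as(inputs_embeds)
)
audio_features = torch.ones(1, n_audio_tokens, hidden)  # (batch, audio_tokens, hidden)

# flatten(0, 1) turns (batch, audio_tokens, hidden) into (batch * audio_tokens, hidden),
# so the element count matches the number of True positions in the mask.
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features.flatten(0, 1))
print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 1., 0., 0.])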
@@ -329,13 +316,13 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                     token_type_ids.to(cache_position.device)
                 )

-
+            attention_mask = {
                 "full_attention": create_causal_mask(**mask_kwargs),
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }

         outputs = self.language_model(
-            attention_mask=
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
@@ -347,12 +334,16 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
             **lm_kwargs,
         )

+        image_hidden_states = None
+        if 'vision_outputs' in locals():
+            image_hidden_states = vision_outputs.hidden_states[-1]
+
         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values if use_cache else None,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            image_hidden_states=
+            image_hidden_states=image_hidden_states,
         )


@@ -476,13 +467,15 @@ class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin)
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

+        image_hidden_states = outputs.image_hidden_states if return_dict else outputs[4]
+
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            image_hidden_states=
+            image_hidden_states=image_hidden_states,
         )


@@ -492,4 +485,4 @@ __all__ = [
     "Gemma3VisionProjector",
     "Gemma3OmniModel",
     "Gemma3OmniForConditionalGeneration",
-]
+]