voidful committed
Commit 67edd60 · verified · 1 Parent(s): 9d8c788

Update modeling_gemma3_omni.py

Files changed (1)
  1. modeling_gemma3_omni.py +48 -55
modeling_gemma3_omni.py CHANGED
@@ -1,20 +1,16 @@
 # -*- coding: utf-8 -*-
 from __future__ import annotations
 
-# ────────────────────────────────────────────────────────────────────────────────
-# 0. Monkey-patch Gemma3TextScaledWordEmbedding.forward --> clone() to break view relation
-# ────────────────────────────────────────────────────────────────────────────────
 import torch
 from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding as _OrigEmb
 
+
 def _patched_forward(self: _OrigEmb, input_ids: torch.LongTensor):
     return super(_OrigEmb, self).forward(input_ids).clone()
 
+
 _OrigEmb.forward = _patched_forward
 
-# ────────────────────────────────────────────────────────────────────────────────
-# 1. Standard imports (the rest of your original file starts here) │
-# ────────────────────────────────────────────────────────────────────────────────
 from typing import List, Optional, Tuple, Union, Callable
 
 from transformers import (
@@ -83,20 +79,24 @@ class Gemma3AudioProjectorConfig(PretrainedConfig):
 
 from torch import nn
 
+
 class LayerWiseWeightedSum(nn.Module):
     def __init__(self, num_layers: int, learnable: bool = True):
         super().__init__()
         self.num_layers = num_layers
         if learnable:
-            self.scalar = nn.Parameter(torch.zeros(num_layers))
+            self.scalar_weights = nn.Parameter(torch.zeros(num_layers))
         else:
-            self.register_buffer("scalar", torch.zeros(num_layers))
+            self.register_buffer("scalar_weights", torch.zeros(num_layers))
 
     def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
-        assert len(hidden_states) == self.num_layers
-        norm_w = torch.softmax(self.scalar, dim=0).view(-1, 1, 1, 1)
-        stacked = torch.stack(hidden_states, dim=0)
-        return (norm_w * stacked).sum(dim=0)
+        if len(hidden_states) != self.num_layers:
+            raise ValueError(f"Expected {self.num_layers} hidden states, but got {len(hidden_states)}")
+
+        norm_weights = torch.softmax(self.scalar_weights, dim=0).view(-1, 1, 1, 1)
+        stacked_states = torch.stack(hidden_states, dim=0)
+        weighted_sum = (norm_weights * stacked_states).sum(dim=0)
+        return weighted_sum
 
 
 class Gemma3AudioProjector(PreTrainedModel):
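
Note: with zero-initialized `scalar_weights` the softmax is uniform, so the module starts out as a plain mean over encoder layers and learns the mix during training. The `.view(-1, 1, 1, 1)` broadcasts against the stacked `(num_layers, B, T, D)` tensor, so each per-layer input is expected to be 3-D. A quick shape check using the class above, with illustrative sizes:

import torch

weighter = LayerWiseWeightedSum(num_layers=3)
hidden = [torch.randn(2, 50, 512) for _ in range(3)]  # three (B, T, D) layers
out = weighter(hidden)                                # still (B, T, D)

assert out.shape == (2, 50, 512)
# softmax over zeros gives 1/3 each, i.e. the mean of the three layers
assert torch.allclose(out, torch.stack(hidden).mean(dim=0), atol=1e-6)
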
@@ -143,19 +143,15 @@ class Gemma3AudioProjector(PreTrainedModel):
         self.layer_weighter = LayerWiseWeightedSum(
             num_layers=encoder_config["num_blocks"]
         )
+        self.norm = Gemma3RMSNorm(encoder_config['attention_dim'], eps=1e-6)
         self.proj = nn.Linear(encoder_config['attention_dim'], config.hidden_size, bias=False)
 
-    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor):
-        mel = mel.squeeze(1)  # (B, T, 80)
-        mel_mask = mel_mask.squeeze(1)  # (B, L)
-
-        if mel_mask.size(1) != mel.size(1):
-            mel_mask = mel_mask[..., : mel.size(1)]
-
+    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         _, out_mask, hidden_list = self.encoder(mel, mel_mask)
-        hidden_sum = self.layer_weighter(hidden_list)
-        hidden = self.proj(hidden_list[-1])
-        return hidden, out_mask
+        features = self.layer_weighter(hidden_list)
+        normalized_features = self.norm(features)
+        projected_features = self.proj(normalized_features)
+        return projected_features, out_mask
 
 
 class Gemma3VisionProjector(nn.Module):
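
Note: the rewritten forward is weighted sum -> RMSNorm -> linear projection; the old body computed `hidden_sum` and then projected `hidden_list[-1]` instead, leaving the weighted sum unused. A standalone sketch of the new pipeline with stand-in modules and illustrative dims (torch's `nn.RMSNorm` substitutes for `Gemma3RMSNorm` here):

import torch
from torch import nn

attention_dim, text_hidden = 512, 2048                # illustrative sizes
norm = nn.RMSNorm(attention_dim, eps=1e-6)            # stand-in for Gemma3RMSNorm
proj = nn.Linear(attention_dim, text_hidden, bias=False)

features = torch.randn(2, 50, attention_dim)          # layer-weighted encoder output
projected = proj(norm(features))                      # (2, 50, text_hidden)
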
@@ -188,6 +184,7 @@ def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
 
     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
         return token_type_ids[batch_idx, kv_idx] != 0
+
     return inner_mask
 
 
@@ -199,7 +196,8 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.multi_modal_projector = Gemma3VisionProjector(config)
         self.audio_projector = Gemma3AudioProjector(
-            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
+            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size, n_mels=config.audio_config.n_mels,
+                                       num_hidden_layers=config.audio_config.num_hidden_layers)
         )
         self.vocab_size = config.text_config.vocab_size
 
@@ -235,7 +233,6 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         **lm_kwargs,
     ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
         if (input_ids is None) ^ (inputs_embeds is not None):
-            print("input_ids:", input_ids, "inputs_embeds:", inputs_embeds)
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -261,15 +258,14 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
         )
 
         if pixel_values is not None and past_key_values is None:
-            image_features = self.get_image_features(pixel_values)
+            vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+            image_features = self.multi_modal_projector(vision_outputs.hidden_states[-1])
 
             if input_ids is None:
-                special_image_mask = inputs_embeds == self.get_input_embeddings()(
-                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
-                )
-            else:
-                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+                raise ValueError("`input_ids` are required when `pixel_values` are provided.")
+
+            special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
 
             if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                 image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
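
Note: `pixel_values` now hard-requires `input_ids`, and the placeholder mask is built from token ids rather than by comparing embedding vectors. Toy illustration of the mask-and-count check (hypothetical token id and sizes):

import torch

image_token_id = 99                                   # hypothetical placeholder id
input_ids = torch.tensor([[5, image_token_id, image_token_id, 7]])
inputs_embeds = torch.zeros(1, 4, 8)

mask = (input_ids == image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
image_features = torch.randn(2, 8)                    # one row per image token
# the guard in the diff compares these counts before scattering features in
assert inputs_embeds[mask].numel() == image_features.numel()
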
@@ -286,23 +282,14 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                 input_audio_embeds, audio_attention_mask
             )
             if input_ids is None:
-                special_audio_mask = (
-                    inputs_embeds
-                    == self.get_input_embeddings()(
-                        torch.tensor(
-                            self.config.audio_token_index,
-                            dtype=torch.long,
-                            device=inputs_embeds.device,
-                        )
-                    )
-                )
-            else:
-                special_audio_mask = (
-                    input_ids == self.config.audio_token_index
-                ).unsqueeze(-1)
-                special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
-                    inputs_embeds.device
-                )
+                raise ValueError("`input_ids` are required when `input_audio_embeds` are provided.")
+
+            special_audio_mask = (
+                input_ids == self.config.audio_token_index
+            ).unsqueeze(-1)
+            special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
             if (
                 not is_torchdynamo_compiling()
                 and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
@@ -314,9 +301,9 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                     f"({audio_features.shape[0] * audio_features.shape[1]})."
                 )
             audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features.flatten(0, 1))
 
-        if not isinstance(causal_mask_mapping := attention_mask, dict):
+        if not isinstance(attention_mask, dict):
             mask_kwargs = {
                 "config": self.config.get_text_config(),
                 "input_embeds": inputs_embeds,
@@ -329,13 +316,13 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
                     token_type_ids.to(cache_position.device)
                 )
 
-            causal_mask_mapping = {
+            attention_mask = {
                 "full_attention": create_causal_mask(**mask_kwargs),
                 "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
             }
 
         outputs = self.language_model(
-            attention_mask=causal_mask_mapping,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
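
Note on `audio_features.flatten(0, 1)`: `masked_scatter` consumes its source in flattened element order, so collapsing `(num_clips, tokens_per_clip, D)` to `(num_clips * tokens_per_clip, D)` makes the one-row-per-audio-token layout explicit. Toy example with a hypothetical token id:

import torch

audio_token_index = 42                                # hypothetical id
input_ids = torch.tensor([[1, audio_token_index, audio_token_index, 2]])
inputs_embeds = torch.zeros(1, 4, 3)

mask = (input_ids == audio_token_index).unsqueeze(-1).expand_as(inputs_embeds)
audio_features = torch.arange(6.0).view(1, 2, 3)      # (clips, tokens, D)
merged = inputs_embeds.masked_scatter(mask, audio_features.flatten(0, 1))

assert torch.equal(merged[0, 1], torch.tensor([0.0, 1.0, 2.0]))
assert torch.equal(merged[0, 2], torch.tensor([3.0, 4.0, 5.0]))

Note on the mask dict: Gemma3 alternates global and sliding-window attention layers, so one mask is prepared per pattern; reusing the `attention_mask` name (instead of the old walrus-bound `causal_mask_mapping`) lets the dict flow straight into the `attention_mask=` argument below. Conceptually the two patterns differ like this (boolean keep-masks, toy sizes):

import torch

T, window = 6, 3
q = torch.arange(T)[:, None]
kv = torch.arange(T)[None, :]

full_attention = kv <= q                                 # plain causal mask
sliding_attention = full_attention & (q - kv < window)   # causal + local window
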
@@ -347,12 +334,16 @@ class Gemma3OmniModel(Gemma3PreTrainedModel):
             **lm_kwargs,
         )
 
+        image_hidden_states = None
+        if 'vision_outputs' in locals():
+            image_hidden_states = vision_outputs.hidden_states[-1]
+
         return Gemma3ModelOutputWithPast(
             last_hidden_state=outputs.last_hidden_state,
             past_key_values=outputs.past_key_values if use_cache else None,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            image_hidden_states=image_features if pixel_values is not None else None,
+            image_hidden_states=image_hidden_states,
         )
 
 
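
Note: `'vision_outputs' in locals()` checks whether the vision branch ran earlier in this call, since the name is only bound inside the `pixel_values` block. Minimal demonstration of the pattern:

def f(flag: bool):
    if flag:
        vision_outputs = "ran"
    # the name exists only when the branch above executed
    return vision_outputs if 'vision_outputs' in locals() else None

assert f(True) == "ran" and f(False) is None
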
@@ -476,13 +467,15 @@ class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
+        image_hidden_states = outputs.image_hidden_states if return_dict else outputs[4]
+
         return Gemma3CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-            image_hidden_states=outputs.image_hidden_states,
+            image_hidden_states=image_hidden_states,
         )
 
 
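
Note: `outputs[4]` assumes `image_hidden_states` is the fifth entry of the tuple form of the model output. Hugging Face `ModelOutput` drops `None` fields when indexed or converted with `to_tuple()`, so that index only lines up when all earlier optional fields are populated; a toy illustration of the behavior this relies on (hypothetical output class):

from dataclasses import dataclass
from typing import Optional
import torch
from transformers.utils import ModelOutput

@dataclass
class ToyOutput(ModelOutput):          # hypothetical output class
    a: Optional[torch.Tensor] = None
    b: Optional[torch.Tensor] = None
    c: Optional[torch.Tensor] = None

out = ToyOutput(a=torch.zeros(1), c=torch.ones(1))   # b stays None
assert len(out.to_tuple()) == 2        # None fields are skipped in tuple form
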
@@ -492,4 +485,4 @@ __all__ = [
     "Gemma3VisionProjector",
     "Gemma3OmniModel",
     "Gemma3OmniForConditionalGeneration",
-]
+]
 