# -*- coding: utf-8 -*-
"""
Gemma3Omni – Whisper v3 多模態架構完整實作
--------------------------------------------------
包含音訊(Whisper v3 projector)、視覺、主模型與生成模型。
"""
from __future__ import annotations
import torch
from torch import nn
from typing import List, Optional, Tuple, Union, Callable
from transformers import (
AutoModel,
AutoProcessor,
WhisperModel,
PreTrainedModel,
PretrainedConfig,
Cache,
)
from transformers.generation import GenerationMixin
from transformers.models.gemma3.modeling_gemma3 import (
Gemma3TextScaledWordEmbedding as _OrigEmb,
Gemma3RMSNorm,
Gemma3PreTrainedModel,
Gemma3ModelOutputWithPast,
Gemma3CausalLMOutputWithPast,
)
from transformers.utils import (
is_torchdynamo_compiling,
logging,
is_torch_flex_attn_available,
)
from .configuration_gemma3_omni import Gemma3OmniConfig
logger = logging.get_logger(__name__)
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask
# --------------------------------------------------------------------
# 1. Audio projector config & projector (Whisper v3 implementation)
# --------------------------------------------------------------------
class Gemma3AudioProjectorConfig(PretrainedConfig):
model_type = "gemma3_audio_whisper"
def __init__(
self,
hidden_size: int = 2304,
encoder_model_name: str = "openai/whisper-small",
projector_num_layers: int = 3,
freeze_whisper: bool = True,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.encoder_model_name = encoder_model_name
self.projector_num_layers = projector_num_layers
self.freeze_whisper = freeze_whisper
class TransformerProjector(nn.Module):
def __init__(self, d_model: int, n_layers: int, n_head: int, dim_ff: int, out_dim: int):
super().__init__()
enc_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=n_head,
dim_feedforward=dim_ff,
dropout=0.1,
activation="gelu",
batch_first=True,
)
self.backbone = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
self.norm = nn.LayerNorm(d_model)
self.out = nn.Linear(d_model, out_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, T, d_model)
x = self.backbone(x)
x = self.norm(x)
return self.out(x) # (B, T, out_dim)
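# Illustrative sketch (not part of the model): shape flow through
# TransformerProjector. The dimensions below assume whisper-small-like
# hyper-parameters (d_model=768, 12 heads, ffn=3072) and a Gemma hidden
# size of 2304; adjust them to whatever checkpoint is actually used.
def _example_transformer_projector() -> torch.Tensor:
    projector = TransformerProjector(
        d_model=768, n_layers=3, n_head=12, dim_ff=3072, out_dim=2304
    )
    dummy_encoder_states = torch.randn(2, 1500, 768)  # (B, T, d_model)
    projected = projector(dummy_encoder_states)       # (B, T, out_dim)
    assert projected.shape == (2, 1500, 2304)
    return projected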
class Gemma3AudioProjector(PreTrainedModel):
config_class = Gemma3AudioProjectorConfig
base_model_prefix = "audio_projector"
def __init__(self, config: Gemma3AudioProjectorConfig):
super().__init__(config)
self.processor = AutoProcessor.from_pretrained(config.encoder_model_name)
whisper_full = WhisperModel.from_pretrained(config.encoder_model_name)
self.whisper_encoder = whisper_full.get_encoder()
if config.freeze_whisper:
self.whisper_encoder.eval()
for p in self.whisper_encoder.parameters():
p.requires_grad = False
d_model = whisper_full.config.d_model
n_head = whisper_full.config.encoder_attention_heads
dim_ff = whisper_full.config.encoder_ffn_dim
self.projector = TransformerProjector(
d_model=d_model,
n_layers=config.projector_num_layers,
n_head=n_head,
dim_ff=dim_ff,
out_dim=config.hidden_size,
)
@torch.no_grad()
def _encode_whisper(self, input_features: torch.Tensor, attention_mask: Optional[torch.Tensor]):
return self.whisper_encoder(
input_features=input_features,
attention_mask=attention_mask,
return_dict=True,
).last_hidden_state
def forward(
self,
        input_features: torch.Tensor,                   # (B, num_mel_bins, num_frames), e.g. (B, 128, 3000) for whisper-large-v3
        attention_mask: Optional[torch.Tensor] = None,  # (B, num_frames)
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
hidden = self._encode_whisper(input_features, attention_mask)
        projected = self.projector(hidden)  # (B, T_enc, hidden_size)
        # NOTE: the Whisper encoder halves the number of input frames (conv
        # stride 2), so a mask supplied at feature-frame resolution is passed
        # through unchanged and will be longer than T_enc; when no mask is
        # given, an all-ones mask matching the projected length is returned.
        out_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(projected.shape[:-1], dtype=torch.bool, device=projected.device)
        )
return projected, out_mask
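# Illustrative sketch (not part of the model): projecting raw audio into the
# Gemma embedding space. Running it downloads Whisper weights, so it is only a
# usage outline; the one-second sine wave and the default whisper-small
# checkpoint are placeholders.
def _example_audio_projection() -> torch.Tensor:
    import math
    config = Gemma3AudioProjectorConfig(hidden_size=2304)
    projector = Gemma3AudioProjector(config)
    # Fake 16 kHz mono audio: one second of a 440 Hz tone.
    waveform = [math.sin(2 * math.pi * 440 * t / 16000) for t in range(16000)]
    features = projector.processor(
        waveform, sampling_rate=16000, return_tensors="pt"
    ).input_features                       # (1, num_mel_bins, 3000)
    projected, mask = projector(features)  # (1, T_enc, 2304), (1, T_enc)
    return projected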
# --------------------------------------------------------------------
# 2. Vision projector (kept from the original Gemma3 implementation)
# --------------------------------------------------------------------
class Gemma3VisionProjector(nn.Module):
def __init__(self, config):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
)
self.mm_soft_emb_norm = Gemma3RMSNorm(
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
)
self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
b, _, seq_len = vision_outputs.shape
x = vision_outputs.transpose(1, 2).reshape(
b, seq_len, self.patches_per_image, self.patches_per_image
)
x = self.avg_pool(x).flatten(2).transpose(1, 2)
x = self.mm_soft_emb_norm(x)
return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)
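# Illustrative sketch (not part of the model): how the pooling arithmetic in
# Gemma3VisionProjector works out with Gemma3-style SigLIP settings
# (image_size=896, patch_size=14, mm_tokens_per_image=256):
#   patches_per_image = 896 // 14 = 64
#   tokens_per_side   = int(256 ** 0.5) = 16
#   kernel_size       = 64 // 16 = 4
# so 64 x 64 = 4096 patch embeddings are average-pooled down to 16 x 16 = 256
# image tokens before projection into the text hidden size. The namespaces
# below are stand-ins for the real Gemma3OmniConfig.
def _example_vision_projector() -> torch.Tensor:
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        vision_config=SimpleNamespace(hidden_size=1152, layer_norm_eps=1e-6,
                                      image_size=896, patch_size=14),
        text_config=SimpleNamespace(hidden_size=2304),
        mm_tokens_per_image=256,
    )
    projector = Gemma3VisionProjector(cfg)
    dummy_patches = torch.randn(2, 4096, 1152)  # (B, num_patches, vision_hidden)
    image_tokens = projector(dummy_patches)     # (B, 256, text_hidden)
    assert image_tokens.shape == (2, 256, 2304)
    return image_tokens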
# --------------------------------------------------------------------
# 3. Token-type-ids mask utility (kept from the original implementation)
# --------------------------------------------------------------------
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
if token_type_ids is None:
return None
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return token_type_ids[batch_idx, kv_idx] != 0
return inner_mask
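# Illustrative sketch (not part of the model): the function above is passed to
# the mask-construction helpers as `or_mask_function`, so any query may attend
# to any key whose token_type_id is non-zero (e.g. image/audio placeholder
# positions), on top of the usual causal pattern.
def _example_token_type_mask() -> None:
    token_type_ids = torch.tensor([[0, 0, 1, 1, 0]])  # positions 2-3 are multimodal
    mask_fn = token_type_ids_mask_function(token_type_ids)
    # A text query (q_idx=4) gets extra permission for the multimodal key at kv_idx=2 ...
    assert bool(mask_fn(0, 0, 4, 2)) is True
    # ... but not for the ordinary text key at kv_idx=0.
    assert bool(mask_fn(0, 0, 4, 0)) is False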
# --------------------------------------------------------------------
# 4. Core multimodal model
# --------------------------------------------------------------------
class Gemma3OmniModel(Gemma3PreTrainedModel):
    config_class = None  # set the appropriate config class externally (e.g. Gemma3OmniConfig)
def __init__(self, config):
super().__init__(config)
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.multi_modal_projector = Gemma3VisionProjector(config)
self.audio_projector = Gemma3AudioProjector(
Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
)
self.vocab_size = config.text_config.vocab_size
language_model = AutoModel.from_config(config=config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
input_audio_embeds: Optional[torch.FloatTensor] = None,
audio_attention_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**lm_kwargs,
) -> Union[Tuple, Gemma3ModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
special_image_mask = input_ids == self.config.image_token_id
llm_input_ids = input_ids.clone()
llm_input_ids[special_image_mask] = 0
else:
llm_input_ids = input_ids
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
        # Multimodal feature injection (images first, then audio)
image_features = None
if pixel_values is not None and past_key_values is None:
image_features = self.get_image_features(pixel_values)
if input_ids is None:
special_image_mask = inputs_embeds == self.get_input_embeddings()(
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
)
else:
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of images does not match number of special image tokens in the input text. "
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
"tokens from image embeddings."
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = torch.where(special_image_mask, image_features, inputs_embeds).contiguous()
if input_audio_embeds is not None and past_key_values is None:
audio_features, audio_feat_mask = self.audio_projector(
input_audio_embeds, audio_attention_mask
)
if input_ids is None:
special_audio_mask = (
inputs_embeds
== self.get_input_embeddings()(
torch.tensor(
self.config.audio_token_index,
dtype=torch.long,
device=inputs_embeds.device,
)
)
)
else:
special_audio_mask = (
input_ids == self.config.audio_token_index
).unsqueeze(-1)
special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
inputs_embeds.device
)
if (
not is_torchdynamo_compiling()
and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
):
audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0]
raise ValueError(
f"Number of audio tokens in the text ({audio_tokens_in_text}) "
f"≠ number of tokens from audio embeddings "
f"({audio_features.shape[0] * audio_features.shape[1]})."
)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
# Attention mask
if not isinstance((causal_mask_mapping := attention_mask), dict):
mask_kwargs = {
"config": self.config.get_text_config(),
"input_embeds": inputs_embeds,
"attention_mask": attention_mask,
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
}
if token_type_ids is not None and inputs_embeds.shape[1] != 1:
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device))
from transformers.masking_utils import (
create_causal_mask,
create_sliding_window_causal_mask,
)
causal_mask_mapping = {
"full_attention": create_causal_mask(**mask_kwargs),
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
}
outputs = self.language_model(
attention_mask=causal_mask_mapping,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cache_position=cache_position,
**lm_kwargs,
)
return Gemma3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
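# Illustrative sketch (not part of the model): how multimodal features replace
# placeholder positions in the token embeddings. The same masked_scatter
# pattern is used above for audio (and torch.where for images); the token id
# and tensor values here are made up purely to show the mechanics.
def _example_placeholder_scatter() -> torch.Tensor:
    hidden = 4
    audio_token_id = 99
    input_ids = torch.tensor([[5, 99, 99, 7]])  # two audio placeholders
    inputs_embeds = torch.zeros(1, 4, hidden)
    audio_features = torch.ones(1, 2, hidden)   # one projected feature per placeholder
    special_audio_mask = (input_ids == audio_token_id).unsqueeze(-1)
    special_audio_mask = special_audio_mask.expand_as(inputs_embeds)
    merged = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
    assert torch.equal(merged[0, 1], torch.ones(hidden))   # placeholder overwritten
    assert torch.equal(merged[0, 0], torch.zeros(hidden))  # text position untouched
    return merged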
# --------------------------------------------------------------------
# 5. Conditional generation wrapper (kept from the original; delegates to the core model above)
# --------------------------------------------------------------------
class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
config_class = Gemma3OmniConfig
_checkpoint_conversion_mapping = {
"^language_model.model": "model.language_model",
"^vision_tower": "model.vision_tower",
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = Gemma3OmniModel(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
return self.model.get_input_embeddings()
def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
input_audio_embeds: Optional[torch.FloatTensor] = None,
audio_attention_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**lm_kwargs,
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
pixel_values=pixel_values,
input_audio_embeds=input_audio_embeds,
audio_attention_mask=audio_attention_mask,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**lm_kwargs,
)
hidden_states = outputs[0]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
logits = logits.float()
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
if attention_mask is not None:
shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
shift_logits = shift_logits[shift_attention_mask != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
shift_logits = shift_logits.contiguous()
shift_labels = shift_labels.contiguous()
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
flat_labels = shift_labels.view(-1).to(shift_logits.device)
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(flat_logits, flat_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return Gemma3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states,
)
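# Illustrative sketch (not part of the model): the next-token loss computed in
# forward() above, reproduced on a toy batch. Logits at step t are scored
# against the label at step t+1, which is why both tensors are shifted by one
# before the cross-entropy.
def _example_shifted_loss() -> torch.Tensor:
    vocab_size = 11
    logits = torch.randn(1, 5, vocab_size)          # (B, S, V)
    labels = torch.randint(0, vocab_size, (1, 5))   # (B, S)
    shift_logits = logits[..., :-1, :].contiguous() # predictions for steps 0..S-2
    shift_labels = labels[..., 1:].contiguous()     # targets are the next tokens
    loss = nn.CrossEntropyLoss()(
        shift_logits.view(-1, vocab_size), shift_labels.view(-1)
    )
    return loss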
# --------------------------------------------------------------------
# Module exports
# --------------------------------------------------------------------
__all__ = [
"Gemma3AudioProjectorConfig",
"Gemma3AudioProjector",
"Gemma3VisionProjector",
"Gemma3OmniModel",
"Gemma3OmniForConditionalGeneration",
]