|
|
|
""" |
|
Gemma3Omni – Whisper v3 多模態架構完整實作 |
|
-------------------------------------------------- |
|
包含音訊(Whisper v3 projector)、視覺、主模型與生成模型。 |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import torch |
|
from torch import nn |
|
from typing import List, Optional, Tuple, Union, Callable |
|
|
|
from transformers import ( |
|
AutoModel, |
|
AutoProcessor, |
|
WhisperModel, |
|
PreTrainedModel, |
|
PretrainedConfig, |
|
Cache, |
|
) |
|
from transformers.generation import GenerationMixin |
|
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3RMSNorm,
    Gemma3PreTrainedModel,
    Gemma3ModelOutputWithPast,
    Gemma3CausalLMOutputWithPast,
)
|
from transformers.utils import ( |
|
is_torchdynamo_compiling, |
|
logging, |
|
is_torch_flex_attn_available, |
|
) |
|
from .configuration_gemma3_omni import Gemma3OmniConfig |
|
logger = logging.get_logger(__name__) |
|
|
|
if is_torch_flex_attn_available(): |
|
from torch.nn.attention.flex_attention import BlockMask |
|
|
|
|
|
|
|
|
|
class Gemma3AudioProjectorConfig(PretrainedConfig):
    """Configuration for the Whisper-based audio projector.

    ``hidden_size`` must match the Gemma3 text hidden size that the projected
    audio embeddings are merged into, and ``encoder_model_name`` selects the
    Whisper checkpoint whose encoder is used (frozen by default; defaults to
    ``openai/whisper-small``, pass e.g. ``openai/whisper-large-v3`` for the v3
    encoder).
    """

    model_type = "gemma3_audio_whisper"
|
|
|
def __init__( |
|
self, |
|
hidden_size: int = 2304, |
|
encoder_model_name: str = "openai/whisper-small", |
|
projector_num_layers: int = 3, |
|
freeze_whisper: bool = True, |
|
**kwargs, |
|
): |
|
super().__init__(**kwargs) |
|
self.hidden_size = hidden_size |
|
self.encoder_model_name = encoder_model_name |
|
self.projector_num_layers = projector_num_layers |
|
self.freeze_whisper = freeze_whisper |
|
|
|
class TransformerProjector(nn.Module):
    """Small Transformer encoder that maps ``(batch, frames, d_model)`` Whisper
    features to ``(batch, frames, out_dim)`` Gemma3-sized embeddings."""

    def __init__(self, d_model: int, n_layers: int, n_head: int, dim_ff: int, out_dim: int):
|
super().__init__() |
|
enc_layer = nn.TransformerEncoderLayer( |
|
d_model=d_model, |
|
nhead=n_head, |
|
dim_feedforward=dim_ff, |
|
dropout=0.1, |
|
activation="gelu", |
|
batch_first=True, |
|
) |
|
self.backbone = nn.TransformerEncoder(enc_layer, num_layers=n_layers) |
|
self.norm = nn.LayerNorm(d_model) |
|
self.out = nn.Linear(d_model, out_dim, bias=False) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
x = self.backbone(x) |
|
x = self.norm(x) |
|
return self.out(x) |
|
|
|
class Gemma3AudioProjector(PreTrainedModel): |
|
config_class = Gemma3AudioProjectorConfig |
|
base_model_prefix = "audio_projector" |
|
|
|
def __init__(self, config: Gemma3AudioProjectorConfig): |
|
super().__init__(config) |
|
        self.processor = AutoProcessor.from_pretrained(config.encoder_model_name)  # kept for downstream convenience; not used inside this module
|
whisper_full = WhisperModel.from_pretrained(config.encoder_model_name) |
|
self.whisper_encoder = whisper_full.get_encoder() |
|
if config.freeze_whisper: |
|
self.whisper_encoder.eval() |
|
for p in self.whisper_encoder.parameters(): |
|
p.requires_grad = False |
|
|
|
d_model = whisper_full.config.d_model |
|
n_head = whisper_full.config.encoder_attention_heads |
|
dim_ff = whisper_full.config.encoder_ffn_dim |
|
self.projector = TransformerProjector( |
|
d_model=d_model, |
|
n_layers=config.projector_num_layers, |
|
n_head=n_head, |
|
dim_ff=dim_ff, |
|
out_dim=config.hidden_size, |
|
) |
|
|
|
@torch.no_grad() |
|
def _encode_whisper(self, input_features: torch.Tensor, attention_mask: Optional[torch.Tensor]): |
|
return self.whisper_encoder( |
|
input_features=input_features, |
|
attention_mask=attention_mask, |
|
return_dict=True, |
|
).last_hidden_state |
|
|
|
    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # The encoder always runs without gradients (the @torch.no_grad() decorator
        # on _encode_whisper), so only the projector receives gradient updates.
        hidden = self._encode_whisper(input_features, attention_mask)
        projected = self.projector(hidden)
        # Note: Whisper's encoder halves the time dimension (stride-2 conv), so a
        # caller-supplied mask over mel frames may not match projected.shape[1]
        # and may need to be downsampled by the caller.
        out_mask = (
            attention_mask
            if attention_mask is not None
            else torch.ones(projected.shape[:-1], dtype=torch.bool, device=projected.device)
        )
        return projected, out_mask
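
# Expected audio shapes (a rough sketch; exact sizes depend on the configured
# Whisper checkpoint): `input_features` is (batch, n_mels, 3000) as produced by
# the Whisper feature extractor, the encoder output is (batch, 1500, d_model)
# because the encoder's second convolution has stride 2, and the projector maps
# that to (batch, 1500, hidden_size) so it can be scattered into the Gemma3
# embedding sequence.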
|
|
|
|
|
|
|
|
|
class Gemma3VisionProjector(nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.mm_input_projection_weight = nn.Parameter( |
|
torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size) |
|
) |
|
self.mm_soft_emb_norm = Gemma3RMSNorm( |
|
config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps |
|
) |
|
self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size |
|
self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5) |
|
self.kernel_size = self.patches_per_image // self.tokens_per_side |
|
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) |
|
|
|
    def forward(self, vision_outputs: torch.Tensor):
        # vision_outputs: (batch, num_patches, hidden). Move channels first and
        # restore the 2D patch grid so the soft tokens can be average-pooled.
        batch_size, _, hidden_dim = vision_outputs.shape
        x = vision_outputs.transpose(1, 2).reshape(
            batch_size, hidden_dim, self.patches_per_image, self.patches_per_image
        )
        x = self.avg_pool(x).flatten(2).transpose(1, 2)
        x = self.mm_soft_emb_norm(x)
        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)
|
|
|
|
|
|
|
|
|
def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
    """Return an ``or_mask_function`` that additionally lets queries attend to
    every non-text (image/audio) key position, i.e. positions whose
    token_type_id is non-zero."""
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return token_type_ids[batch_idx, kv_idx] != 0

    return inner_mask
|
|
|
|
|
|
|
|
|
class Gemma3OmniModel(Gemma3PreTrainedModel): |
|
    config_class = Gemma3OmniConfig
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.vision_tower = AutoModel.from_config(config=config.vision_config) |
|
self.multi_modal_projector = Gemma3VisionProjector(config) |
|
self.audio_projector = Gemma3AudioProjector( |
|
Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size) |
|
) |
|
self.vocab_size = config.text_config.vocab_size |
|
|
|
language_model = AutoModel.from_config(config=config.text_config) |
|
self.language_model = language_model |
|
|
|
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 |
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.language_model.get_input_embeddings() |
|
|
|
def set_input_embeddings(self, value): |
|
self.language_model.set_input_embeddings(value) |
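
    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Referenced by forward() below; a minimal helper mirroring the upstream
        # Gemma3Model.get_image_features: encode the pixels with the vision tower
        # and project the patch embeddings into the text embedding space.
        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        return self.multi_modal_projector(vision_outputs)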
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
pixel_values: torch.FloatTensor = None, |
|
input_audio_embeds: Optional[torch.FloatTensor] = None, |
|
audio_attention_mask: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**lm_kwargs, |
|
) -> Union[Tuple, Gemma3ModelOutputWithPast]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if input_ids is not None and self.config.image_token_id >= self.vocab_size: |
|
special_image_mask = input_ids == self.config.image_token_id |
|
llm_input_ids = input_ids.clone() |
|
llm_input_ids[special_image_mask] = 0 |
|
else: |
|
llm_input_ids = input_ids |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.get_input_embeddings()(llm_input_ids) |
|
|
|
if cache_position is None: |
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
cache_position = torch.arange( |
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
) |
|
|
|
|
|
image_features = None |
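        # NOTE: image/audio features are only merged when no KV cache object is
        # present. If `generate()` pre-allocates a cache before the first forward
        # pass, these gates may need to test `cache_position[0] == 0` instead, as
        # the upstream Gemma3 generation path does.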
|
if pixel_values is not None and past_key_values is None: |
|
image_features = self.get_image_features(pixel_values) |
|
|
|
if input_ids is None: |
|
special_image_mask = inputs_embeds == self.get_input_embeddings()( |
|
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) |
|
) |
|
else: |
|
special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) |
|
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
|
|
|
if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): |
|
image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] |
|
raise ValueError( |
|
f"Number of images does not match number of special image tokens in the input text. " |
|
f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " |
|
"tokens from image embeddings." |
|
) |
|
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) |
|
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
|
|
if input_audio_embeds is not None and past_key_values is None: |
|
audio_features, audio_feat_mask = self.audio_projector( |
|
input_audio_embeds, audio_attention_mask |
|
) |
|
if input_ids is None: |
|
special_audio_mask = ( |
|
inputs_embeds |
|
== self.get_input_embeddings()( |
|
torch.tensor( |
|
self.config.audio_token_index, |
|
dtype=torch.long, |
|
device=inputs_embeds.device, |
|
) |
|
) |
|
) |
|
else: |
|
special_audio_mask = ( |
|
input_ids == self.config.audio_token_index |
|
).unsqueeze(-1) |
|
special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to( |
|
inputs_embeds.device |
|
) |
|
if ( |
|
not is_torchdynamo_compiling() |
|
and inputs_embeds[special_audio_mask].numel() != audio_features.numel() |
|
): |
|
audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0] |
|
raise ValueError( |
|
f"Number of audio tokens in the text ({audio_tokens_in_text}) " |
|
f"≠ number of tokens from audio embeddings " |
|
f"({audio_features.shape[0] * audio_features.shape[1]})." |
|
) |
|
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) |
|
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features) |
|
|
|
|
|
if not isinstance((causal_mask_mapping := attention_mask), dict): |
|
mask_kwargs = { |
|
"config": self.config.get_text_config(), |
|
"input_embeds": inputs_embeds, |
|
"attention_mask": attention_mask, |
|
"position_ids": position_ids, |
|
"cache_position": cache_position, |
|
"past_key_values": past_key_values, |
|
} |
|
if token_type_ids is not None and inputs_embeds.shape[1] != 1: |
|
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(token_type_ids.to(cache_position.device)) |
|
|
|
from transformers.masking_utils import ( |
|
create_causal_mask, |
|
create_sliding_window_causal_mask, |
|
) |
|
|
|
causal_mask_mapping = { |
|
"full_attention": create_causal_mask(**mask_kwargs), |
|
"sliding_attention": create_sliding_window_causal_mask(**mask_kwargs), |
|
} |
|
|
|
outputs = self.language_model( |
|
attention_mask=causal_mask_mapping, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=True, |
|
cache_position=cache_position, |
|
**lm_kwargs, |
|
) |
|
|
|
return Gemma3ModelOutputWithPast( |
|
last_hidden_state=outputs.last_hidden_state, |
|
past_key_values=outputs.past_key_values if use_cache else None, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
image_hidden_states=image_features if pixel_values is not None else None, |
|
) |
|
|
|
|
|
|
|
|
|
class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin): |
|
config_class = Gemma3OmniConfig |
|
|
|
_checkpoint_conversion_mapping = { |
|
"^language_model.model": "model.language_model", |
|
"^vision_tower": "model.vision_tower", |
|
"^multi_modal_projector": "model.multi_modal_projector", |
|
"^language_model.lm_head": "lm_head", |
|
} |
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.model = Gemma3OmniModel(config) |
|
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) |
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.model.get_input_embeddings() |
|
|
|
def set_input_embeddings(self, value): |
|
self.model.set_input_embeddings(value) |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
self.lm_head = new_embeddings |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
pixel_values: torch.FloatTensor = None, |
|
input_audio_embeds: Optional[torch.FloatTensor] = None, |
|
audio_attention_mask: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
logits_to_keep: Union[int, torch.Tensor] = 0, |
|
**lm_kwargs, |
|
) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.model( |
|
input_ids=input_ids, |
|
pixel_values=pixel_values, |
|
input_audio_embeds=input_audio_embeds, |
|
audio_attention_mask=audio_attention_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
use_cache=use_cache, |
|
labels=labels, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
cache_position=cache_position, |
|
**lm_kwargs, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
|
logits = self.lm_head(hidden_states[:, slice_indices, :]) |
|
|
|
loss = None |
|
if labels is not None: |
|
logits = logits.float() |
|
shift_logits = logits[..., :-1, :] |
|
shift_labels = labels[..., 1:] |
|
|
|
if attention_mask is not None: |
|
shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device) |
|
shift_logits = shift_logits[shift_attention_mask != 0].contiguous() |
|
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() |
|
else: |
|
shift_logits = shift_logits.contiguous() |
|
shift_labels = shift_labels.contiguous() |
|
|
|
flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) |
|
flat_labels = shift_labels.view(-1).to(shift_logits.device) |
|
|
|
loss_fct = nn.CrossEntropyLoss() |
|
loss = loss_fct(flat_logits, flat_labels) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return Gemma3CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
image_hidden_states=outputs.image_hidden_states, |
|
) |
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
"Gemma3AudioProjectorConfig", |
|
"Gemma3AudioProjector", |
|
"Gemma3VisionProjector", |
|
"Gemma3OmniModel", |
|
"Gemma3OmniForConditionalGeneration", |
|
] |
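

# Example usage (a minimal sketch, assuming a `Gemma3OmniConfig` whose
# `text_config`/`vision_config` sub-configs are populated and a processor that
# yields `input_ids`, `pixel_values`, and Whisper log-mel `input_audio_embeds`):
#
#   model = Gemma3OmniForConditionalGeneration.from_pretrained("path/to/gemma3-omni")
#   outputs = model(
#       input_ids=batch["input_ids"],
#       attention_mask=batch["attention_mask"],
#       pixel_values=batch.get("pixel_values"),
#       input_audio_embeds=batch.get("input_audio_embeds"),
#       labels=batch["labels"],
#   )
#   outputs.loss.backward()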
|
|