# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # This file was automatically generated from src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py. # Do NOT edit this file manually as any edits will be overwritten by the generation of # the file from the modular. If any change should be done, please apply the change to the # modular_phi4_multimodal.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # Copyright 2025 Microsoft and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import math import warnings from functools import wraps from typing import Callable, List, Optional, Tuple, Union, Any import numpy as np import torch import torch.nn.functional as F from torch import nn from torch.nn.init import _calculate_fan_in_and_fan_out from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from transformers.processing_utils import Unpack from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, torch_int, ) from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig logger = logging.get_logger(__name__) def set_attribute_for_modules(module: "torch.nn.Module", key: str, value: Any): """ Set a value to a module and all submodules. """ setattr(module, key, value) for submodule in module.children(): set_attribute_for_modules(submodule, key, value) def del_attribute_from_modules(module: "torch.nn.Module", key: str): """ Delete a value from a module and all submodules. """ # because we might remove it previously in case it's a shared module, e.g. activation function if hasattr(module, key): delattr(module, key) for submodule in module.children(): del_attribute_from_modules(submodule, key) def can_return_tuple(func): """ Decorator to wrap model method, to call output.to_tuple() if return_dict=False passed as a kwarg or use_return_dict=False is set in the config. Note: output.to_tuple() convert output to tuple skipping all `None` values. 
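
    Example (illustrative sketch; `MyModel` below is a hypothetical module whose `forward` returns a `ModelOutput`):

        class MyModel(nn.Module):
            @can_return_tuple
            def forward(self, hidden_states):
                return BaseModelOutput(last_hidden_state=hidden_states)

        model = MyModel()
        hidden_states = torch.randn(1, 4, 8)
        output = model(hidden_states)                     # BaseModelOutput instance
        output = model(hidden_states, return_dict=False)  # plain tuple, `None` entries skipped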
""" @wraps(func) def wrapper(self, *args, **kwargs): is_requested_to_return_tuple = kwargs.pop("return_dict", True) is False is_configured_to_return_tuple = self.config.use_return_dict is False if hasattr(self, "config") else False # The following allows to convert output to tuple ONLY on top level forward call, # while internal modules of the model will return Output objects # to be able to use name-based attribute access in modeling code. # We will check if we are on top level module, if so, turn off to tuple conversion for all # underling calls. is_top_level_module = getattr(self, "_is_top_level_module", True) if is_configured_to_return_tuple and is_top_level_module: set_attribute_for_modules(self, "_is_top_level_module", False) try: output = func(self, *args, **kwargs) if is_requested_to_return_tuple or (is_configured_to_return_tuple and is_top_level_module): output = output.to_tuple() finally: # Remove the flag after the model forward call is finished. if is_configured_to_return_tuple and is_top_level_module: del_attribute_from_modules(self, "_is_top_level_module") return output return wrapper def dynamic_rope_update(rope_forward): """ Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). Args: rope_forward (Callable): The forward pass of the RoPE implementation. Returns: The decorated forward pass. """ def longrope_frequency_update(self, position_ids, device): """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" seq_len = torch.max(position_ids) + 1 if hasattr(self.config, "original_max_position_embeddings"): original_max_position_embeddings = self.config.original_max_position_embeddings else: original_max_position_embeddings = self.config.max_position_embeddings if seq_len > original_max_position_embeddings: if not hasattr(self, "long_inv_freq"): self.long_inv_freq, _ = self.rope_init_fn( self.config, device, seq_len=original_max_position_embeddings + 1 ) self.register_buffer("inv_freq", self.long_inv_freq, persistent=False) else: # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) def dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: 1 - growing beyond the cached sequence length (allow scaling) 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) """ seq_len = torch.max(position_ids) + 1 if seq_len > self.max_seq_len_cached: # growth inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len) self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation self.max_seq_len_cached = seq_len if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset # This .to() is needed if the model has been moved to a device after being initialized (because # the buffer is automatically moved, but not the original copy) self.original_inv_freq = self.original_inv_freq.to(device) self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) self.max_seq_len_cached = self.original_max_seq_len @wraps(rope_forward) def 
wrapper(self, x, position_ids): if "dynamic" in self.rope_type: dynamic_frequency_update(self, position_ids, device=x.device) elif self.rope_type == "longrope": longrope_frequency_update(self, position_ids, device=x.device) return rope_forward(self, x, position_ids) return wrapper class Phi4MultimodalVisionMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) return hidden_states def simple_eager_attention_forward( module: nn.Module, query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs, ): attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling if attention_mask is not None: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() return attn_output, attn_weights class Phi4MultimodalVisionAttention(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.embed_dim // self.num_heads self.scaling = self.head_dim**-0.5 self.is_causal = True self.attention_dropout = config.attention_dropout self.k_proj = nn.Linear(config.hidden_size, config.hidden_size) self.v_proj = nn.Linear(config.hidden_size, config.hidden_size) self.q_proj = nn.Linear(config.hidden_size, config.hidden_size) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Input shape: Batch x Time x Channel""" input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1) attn_output = self.out_proj(attn_output) return attn_output, attn_weights class Phi4MultimodalVisionEncoderLayer(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.embed_dim = config.hidden_size self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.self_attn = Phi4MultimodalVisionAttention(config) self.layer_norm2 = 
nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Phi4MultimodalVisionMLP(config) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, ) -> Tuple[torch.FloatTensor]: """ Args: hidden_states (`torch.FloatTensor`): Input to the layer of shape `(batch, seq_len, embed_dim)`. attention_mask (`torch.FloatTensor`): Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. output_attentions (`bool`, *optional*, defaults to `False`): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) hidden_states, attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (attn_weights,) return outputs class Phi4MultimodalVisionEncoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`Phi4MultimodalVisionEncoderLayer`]. Args: config: Phi4MultimodalVisionConfig """ def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList( [Phi4MultimodalVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False # Ignore copy @can_return_tuple def forward( self, inputs_embeds, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
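
        Example (illustrative sketch; `config` is an assumed `Phi4MultimodalVisionConfig` and the shapes are arbitrary):

            encoder = Phi4MultimodalVisionEncoder(config)
            inputs_embeds = torch.randn(1, 256, config.hidden_size)
            outputs = encoder(inputs_embeds)
            last_hidden_state = outputs.last_hidden_state  # shape (1, 256, config.hidden_size)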
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) def _trunc_normal_(tensor, mean, std, a, b): # Cut & paste from PyTorch official master until it's in a few official releases - RW # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf def norm_cdf(x): # Computes standard normal cumulative distribution function return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 if (mean < a - 2 * std) or (mean > b + 2 * std): warnings.warn( "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " "The distribution of values may be incorrect.", stacklevel=2, ) # Values are generated by using a truncated uniform distribution and # then using the inverse CDF for the normal distribution. # Get upper and lower cdf values l = norm_cdf((a - mean) / std) u = norm_cdf((b - mean) / std) # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. tensor.uniform_(2 * l - 1, 2 * u - 1) # Use inverse cdf transform for normal distribution to get truncated # standard normal tensor.erfinv_() # Transform to proper mean, std tensor.mul_(std * math.sqrt(2.0)) tensor.add_(mean) # Clamp to ensure it's in the proper range tensor.clamp_(min=a, max=b) def trunc_normal_tf_( tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 ) -> torch.Tensor: """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for generating the random values works best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 and the result is subsequently scaled and shifted by the mean and std args. 
Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution std: the standard deviation of the normal distribution a: the minimum cutoff value b: the maximum cutoff value """ with torch.no_grad(): _trunc_normal_(tensor, 0, 1.0, a, b) tensor.mul_(std).add_(mean) def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) if mode == "fan_in": denom = fan_in elif mode == "fan_out": denom = fan_out elif mode == "fan_avg": denom = (fan_in + fan_out) / 2 variance = scale / denom if distribution == "truncated_normal": # constant is stddev of standard normal truncated to (-2, 2) trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) elif distribution == "normal": with torch.no_grad(): tensor.normal_(std=math.sqrt(variance)) elif distribution == "uniform": bound = math.sqrt(3 * variance) with torch.no_grad(): tensor.uniform_(-bound, bound) else: raise ValueError(f"invalid distribution {distribution}") def lecun_normal_(tensor): variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") class Phi4MultimodalVisionPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ config_class = Phi4MultimodalVisionConfig base_model_prefix = "phi4_vision" supports_gradient_checkpointing = True _no_split_modules = ["Phi4MultimodalVisionEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" if isinstance(module, Phi4MultimodalVisionEmbeddings): width = ( self.config.hidden_size if isinstance(self.config, Phi4MultimodalVisionConfig) else self.config.hidden_size ) nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) elif isinstance(module, nn.Embedding): default_flax_embed_init(module.weight) elif isinstance(module, Phi4MultimodalVisionAttention): nn.init.normal_(module.q_proj.weight) nn.init.normal_(module.k_proj.weight) nn.init.normal_(module.v_proj.weight) nn.init.normal_(module.out_proj.weight) nn.init.zeros_(module.q_proj.bias) nn.init.zeros_(module.k_proj.bias) nn.init.zeros_(module.v_proj.bias) nn.init.zeros_(module.out_proj.bias) elif isinstance(module, Phi4MultimodalVisionMLP): nn.init.normal_(module.fc1.weight) nn.init.normal_(module.fc2.weight) nn.init.normal_(module.fc1.bias, std=1e-6) nn.init.normal_(module.fc2.bias, std=1e-6) elif isinstance(module, Phi4MultimodalVisionMultiheadAttentionPoolingHead): nn.init.normal_(module.probe.data) nn.init.normal_(module.attention.in_proj_weight.data) nn.init.zeros_(module.attention.in_proj_bias.data) elif isinstance(module, (nn.Linear, nn.Conv2d)): lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) class Phi4MultimodalVisionEmbeddings(nn.Module): def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.config = config self.patch_size = config.patch_size self.num_patches_per_side = config.image_size // self.patch_size self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, padding="valid", ) self.position_embedding = nn.Embedding(self.num_patches_per_side**2, 
config.hidden_size) def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: """ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution images. This method is also adapted to support torch.jit tracing and no class embeddings. Adapted from: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 """ num_patches = embeddings.shape[1] num_positions = self.position_embedding.weight.shape[0] # always interpolate when tracing to ensure the exported model works for dynamic input shapes if not torch.jit.is_tracing() and num_patches == num_positions and height == width: return self.position_embedding(self.position_ids) patch_pos_embed = self.position_embedding.weight.unsqueeze(0) dim = embeddings.shape[-1] new_height = height // self.patch_size new_width = width // self.patch_size sqrt_num_positions = torch_int(num_positions**0.5) patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) patch_pos_embed = nn.functional.interpolate( patch_pos_embed, size=(new_height, new_width), mode="bicubic", align_corners=False, ) patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return patch_pos_embed def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: batch_size = pixel_values.size(0) patch_embeds = self.patch_embedding(pixel_values) embeddings = patch_embeds.flatten(2).transpose(1, 2) max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) position_ids = torch.full((batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) for batch_idx, p_attn_mask in enumerate(patch_attention_mask): nb_patches_h = p_attn_mask[:, 0].sum() nb_patches_w = p_attn_mask[0].sum() fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) embeddings = embeddings + self.position_embedding(position_ids) return embeddings class Phi4MultimodalVisionMultiheadAttentionPoolingHead(nn.Module): """Multihead Attention Pooling.""" def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__() self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = Phi4MultimodalVisionMLP(config) def forward(self, hidden_state, attention_mask): batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) hidden_state = self.attention( query=probe, key=hidden_state, value=hidden_state, 
key_padding_mask=~attention_mask )[0] residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) return hidden_state[:, 0] class Phi4MultimodalVisionModel(Phi4MultimodalVisionPreTrainedModel): config_class = Phi4MultimodalVisionConfig main_input_name = "pixel_values" def __init__(self, config: Phi4MultimodalVisionConfig): super().__init__(config) self.config = config self.embeddings = Phi4MultimodalVisionEmbeddings(config) self.encoder = Phi4MultimodalVisionEncoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.head = Phi4MultimodalVisionMultiheadAttentionPoolingHead(config) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embedding def forward( self, pixel_values, patch_attention_mask: Optional[torch.BoolTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutputWithPooling: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) batch_size = pixel_values.size(0) if patch_attention_mask is None: patch_attention_mask = torch.ones( size=( batch_size, pixel_values.size(2) // self.config.patch_size, pixel_values.size(3) // self.config.patch_size, ), dtype=torch.bool, device=pixel_values.device, ) hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) patch_attention_mask = patch_attention_mask.view(batch_size, -1) # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): attention_mask = None else: attention_mask = ( _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self.config._attn_implementation == "flash_attention_2" else patch_attention_mask ) encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) pooled_output = self.head( hidden_state=last_hidden_state, attention_mask=patch_attention_mask, ) return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, ) class Phi4MultimodalImageEmbedding(nn.Module): """Image embedding.""" def __init__(self, config: Phi4MultimodalConfig): super().__init__() self.config = config self.layer_idx = config.vision_config.feature_layer self.crop_size = config.vision_config.crop_size self.image_dim_out = config.vision_config.hidden_size n_patches = config.vision_config.image_size // config.vision_config.patch_size if n_patches % 2 != 0: self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) n_patches += 1 self.num_img_tokens = (n_patches // 2) ** 2 self.drop = nn.Dropout(config.embd_pdrop) self.img_processor = Phi4MultimodalVisionModel._from_config(config.vision_config) self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) self.img_projection_up = nn.Linear(self.image_dim_out, config.hidden_size) self.img_projection_down = nn.Linear(config.hidden_size, config.hidden_size) self.global_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, self.image_dim_out])) self.sub_img_feature_extensor = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out])) def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: img_processor_output = self.img_processor( img_embeds, patch_attention_mask=attention_mask, output_hidden_states=True ) img_feature = img_processor_output.hidden_states[self.layer_idx] patch_feature = img_feature # reshape to 2D tensor width = int(math.sqrt(patch_feature.size(1))) patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) # convert to NCHW patch_feature = patch_feature.permute(0, 3, 1, 2) if getattr(self, "img_processor_padding", None) is not None: patch_feature = self.img_processor_padding(patch_feature) patch_feature = self.image_token_compression(patch_feature) # convert to NHWC patch_feature = patch_feature.permute(0, 2, 3, 1) patch_feature = patch_feature.view(-1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1)) return patch_feature def forward( self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, image_pixel_values: torch.FloatTensor, image_sizes: Optional[torch.Tensor] = None, image_attention_mask: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: image_pixel_values = image_pixel_values.to(self.img_processor.embeddings.patch_embedding.weight.dtype) target_device = self.img_projection_up.bias.device target_dtype = self.img_projection_up.bias.dtype batch_size = image_pixel_values.shape[0] img_features = self.get_img_features( image_pixel_values.flatten(0, 1), attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool, 
device=target_device), ) base_feat_size = int(np.sqrt(img_features.shape[1])) img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out) image_sizes = image_sizes.view(-1, 2) output_imgs = [] for idx in range(batch_size): height, width = image_sizes[idx] height_ratio = height // self.crop_size width_ratio = width // self.crop_size area_ratio = height_ratio * width_ratio global_img = img_features[idx, :1] global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous() temporary_extensor = self.sub_img_feature_extensor.repeat(1, base_feat_size, 1, 1) global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) sub_img = img_features[idx, 1:] sub_img = sub_img[:area_ratio] sub_img = ( sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out) .transpose(1, 2) .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out) .contiguous() ) if image_attention_mask is not None: reshaped_image_attention_mask = ( image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) .transpose(1, 2) .reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size) ) useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) sub_img = sub_img[:, :useful_height, :useful_width] temporary_extensor = self.sub_img_feature_extensor.repeat(1, useful_height, 1, 1) else: temporary_extensor = self.sub_img_feature_extensor.repeat(1, height_ratio * base_feat_size, 1, 1) sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out) # Merge global and sub output_imgs.append(torch.cat([sub_img, self.global_img_feature_extensor, global_img], dim=1)) img_set_tensor = [] for output_img in output_imgs: output_img = output_img.to(device=target_device, dtype=target_dtype) img_feature_proj = self.img_projection_up(output_img) img_feature_proj = nn.functional.gelu(img_feature_proj) img_feature_proj = self.img_projection_down(img_feature_proj) img_set_tensor.append(img_feature_proj) merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) merged_img_set_tensor = merged_img_set_tensor.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) with torch.no_grad(): positions_tuple = torch.nonzero(input_ids == self.config.vision_config.image_token_id, as_tuple=True) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): image_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_img_set_tensor, accumulate=False ) image_embeds = self.drop(image_embeds) return image_embeds ########################################################## AUDIO ############################################# class Phi4MultimodalAudioMLP(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.layer_norm = nn.LayerNorm(config.hidden_size) self.act_fn = ACT2FN[config.activation] self.gate_up_proj = nn.Linear(config.hidden_size, config.intermediate_size * 2) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states): hidden_states = self.layer_norm(hidden_states) up_states = self.gate_up_proj(hidden_states) up_states, gate = 
up_states.chunk(2, dim=-1) up_states = up_states * self.act_fn(gate) up_states = self.dropout(up_states) hidden_states = self.down_proj(up_states) out = self.dropout(hidden_states) return out class Phi4MultimodalAudioAttention(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.scaling = self.head_dim**-0.5 self.attention_dropout = config.dropout_rate self.is_causal = True self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.v_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=True) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs, ): input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) attention_interface: Callable = simple_eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, _ = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output class Phi4MultimodalAudioDepthWiseSeperableConv1d(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig, padding: int = 0): super().__init__() self.dw_conv = nn.Conv1d( config.hidden_size, config.hidden_size * config.depthwise_multiplier, config.kernel_size, 1, padding=padding, groups=config.hidden_size, ) self.pw_conv = nn.Conv1d( config.hidden_size * config.depthwise_multiplier, config.depthwise_seperable_out_channel, 1, 1, 0 ) def forward(self, hidden_states): return self.pw_conv(self.dw_conv(hidden_states)) class Phi4MultimodalAudioGluPointWiseConv(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.output_dim = config.ext_pw_out_channel self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel * 2, kernel_size=1, stride=1) self.glu_act = ACT2FN[config.conv_glu_type] self.b1 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) self.b2 = nn.Parameter(torch.zeros(1, config.ext_pw_out_channel, 1)) def forward(self, hidden_states): # we assume the input always has the #channel (#dim) in the last dimension of the # tensor, so need to switch the dimension first for 1D-Conv case hidden_states = hidden_states.permute([0, 2, 1]) hidden_states = self.ext_pw_conv_1d(hidden_states) out = hidden_states[:, 0 : self.output_dim, :] + self.b1 out = out * self.glu_act(hidden_states[:, self.output_dim : self.output_dim * 2, :] + self.b2) return out.permute([0, 2, 1]) class Phi4MultimodalAudioConvModule(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.config = config self.kernel_size = config.kernel_size self.layer_norm = nn.LayerNorm(config.hidden_size) 
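        # Conformer convolution module: a pointwise GLU convolution, followed by a depthwise separable 1D
        # convolution with causal padding (the trailing `kernel_size - 1` frames are trimmed in `forward`),
        # an activation, a final pointwise convolution, and dropout.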
self.glu = Phi4MultimodalAudioGluPointWiseConv(config) self.dw_sep_conv_1d = Phi4MultimodalAudioDepthWiseSeperableConv1d(config, padding=config.kernel_size - 1) self.act = ACT2FN[config.conv_activation] self.ext_pw_conv_1d = nn.Conv1d(config.hidden_size, config.ext_pw_out_channel, kernel_size=1, stride=1) self.dropout = nn.Dropout(config.dropout_rate) def forward(self, hidden_states: torch.Tensor): hidden_states = self.glu(self.layer_norm(hidden_states)) hidden_states = self.dw_sep_conv_1d(hidden_states.permute([0, 2, 1])) if self.kernel_size > 1: hidden_states = hidden_states[:, :, : -(self.kernel_size - 1)] hidden_states = self.act(hidden_states) hidden_states = self.ext_pw_conv_1d(hidden_states) out = self.dropout(hidden_states.permute([0, 2, 1])) return out class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.feed_forward_in = Phi4MultimodalAudioMLP(config) self.self_attn = Phi4MultimodalAudioAttention(config) self.conv = Phi4MultimodalAudioConvModule(config) self.feed_forward_out = Phi4MultimodalAudioMLP(config) self.layer_norm_att = nn.LayerNorm(config.hidden_size) self.layer_norm = nn.LayerNorm(config.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, ): residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) hidden_states = self.layer_norm_att(residual) hidden_states = residual + self.self_attn(hidden_states, attention_mask) hidden_states = hidden_states + self.conv(hidden_states) hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) out = self.layer_norm(hidden_states) return out class Phi4MultimodalAudioNemoConvSubsampling(torch.nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.subsampling_factor = config.time_reduction self.sampling_num = int(math.log(self.subsampling_factor, 2)) self.act_fn = ACT2FN[config.nemo_activation] conv_channels = config.nemo_conv_channels layers = [ nn.Conv2d(1, conv_channels, kernel_size=3, stride=2, padding=1), self.act_fn, ] for _ in range(self.sampling_num - 1): layers.extend( [ nn.Conv2d(conv_channels, conv_channels, kernel_size=3, stride=2, padding=1, groups=conv_channels), nn.Conv2d(conv_channels, conv_channels, kernel_size=1, stride=1, padding=0, groups=1), self.act_fn, ] ) # Aggregate the layers self.conv = torch.nn.Sequential(*layers) self.out = torch.nn.Linear(conv_channels * config.nemo_final_size, config.hidden_size) def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): # Unsqueeze Channel Axis hidden_states = hidden_states.unsqueeze(1) hidden_states = self.conv(hidden_states) # Flatten Channel and Frequency Axes b, _, t, _ = hidden_states.size() hidden_states = self.out(hidden_states.transpose(1, 2).reshape(b, t, -1)) if mask is None: return hidden_states, None max_audio_length = hidden_states.shape[1] feature_lens = mask.sum(1) padding_length = torch.ceil(feature_lens / self.subsampling_factor) arange_ = torch.arange(0, max_audio_length, device=hidden_states.device) pad_mask = arange_.expand(padding_length.size(0), -1) < padding_length.unsqueeze(1) return hidden_states, pad_mask.unsqueeze(1) class Phi4MultimodalAudioRelativeAttentionBias(nn.Module): def __init__(self, config: Phi4MultimodalAudioConfig): super().__init__() self.max_distance = config.bias_max_distance self.symmetric = config.bias_symmetric self.num_buckets = self.max_distance if not config.bias_symmetric: self.num_buckets *= 2 self.bias_values = 
nn.Embedding(self.num_buckets, config.num_attention_heads)

    def forward(self, x):
        # instantiate bias compatible with shape of x
        max_pos = x.size(1)
        context_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[:, None]
        memory_position = torch.arange(max_pos, device=x.device, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        # clipping to a maximum distance using ops that play well with ONNX export
        relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance)
        relative_position = relative_position.masked_fill(
            relative_position > self.max_distance - 1, self.max_distance - 1
        )
        # mapping from relative position to index in the bias parameter
        bias_idx = relative_position
        bias_idx = bias_idx.abs() if self.symmetric else bias_idx + self.num_buckets // 2
        att_bias = self.bias_values(bias_idx)
        att_bias = att_bias.permute(2, 0, 1).unsqueeze(0)
        return att_bias


class Phi4MultimodalAudioMeanVarianceNormLayer(nn.Module):
    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__()
        self.register_buffer("global_mean", torch.zeros(config.input_size))
        self.register_buffer("global_invstd", torch.ones(config.input_size))

    def forward(self, x):
        return (x - self.global_mean) * self.global_invstd


class Phi4MultimodalAudioPreTrainedModel(PreTrainedModel):
    config_class = Phi4MultimodalAudioConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi4MultimodalAudioConformerEncoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, Phi4MultimodalAudioGluPointWiseConv):
            module.b1.data.zero_()
            module.b2.data.zero_()


def unfold_tensor(tensor, max_seq_len):
    """
    For a given tensor of shape (N, T, D), if the sequence length T is longer than max_seq_len, this function
    unfolds it to (N * T', max_seq_len, D), where T' is T // max_seq_len.

    Args:
        tensor: input tensor of shape (N, T, D)
    """
    _, _, D = tensor.shape
    tensor = tensor.transpose(-1, -2)
    # N x D x 1 x T => N x (D x max_seq_len) x T'
    tensor = F.unfold(tensor[..., None, :], kernel_size=(1, max_seq_len), stride=(1, max_seq_len))
    new_bsz, _, slen = tensor.shape
    tensor = tensor.view(new_bsz, -1, max_seq_len, slen)
    tensor = tensor.permute(0, 3, 2, 1)
    tensor = tensor.view(-1, max_seq_len, D).contiguous()
    return tensor


def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """
    Build the chunk-based attention mask used by the Transformer Transducer streaming mode.

    Args:
        x_len (int): sequence length
        chunk_start_idx (list): first index of each chunk, such as [0, 18, 36, 48]. It also supports adaptive
            chunk sizes, e.g. [0, 10, 15, 45]
        left_window (int): how many left chunks can be seen
        right_window (int): how many right chunks can be seen. It is used for the chunk overlap model.
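
    Example (illustrative):
        `adaptive_enc_mask(6, [0, 2, 4])` returns a `(6, 6)` boolean mask that is block-diagonal over the chunks
        `[0, 2)`, `[2, 4)` and `[4, 6)`; with `left_window=1`, each position can additionally attend to the chunk
        immediately preceding its own.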
    Returns:
        mask (torch.Tensor): a mask tensor for the streaming model
    """
    chunk_start_idx = torch.Tensor(chunk_start_idx).long()
    start_pad = torch.nn.functional.pad(
        chunk_start_idx, (1, 0)
    )  # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
    end_pad = torch.nn.functional.pad(
        chunk_start_idx, (0, 1), value=x_len
    )  # append x_len to the end, so it becomes [0, 18, 36, 48, x_len]
    seq_range = torch.arange(0, x_len).unsqueeze(-1)
    idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
    seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
    idx_left = idx - left_window
    idx_left[idx_left < 0] = 0
    boundary_left = start_pad[idx_left]
    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
    idx_right = idx + right_window
    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
    boundary_right = end_pad[idx_right]
    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
    return mask_left & mask_right


class Phi4MultimodalAudioModel(Phi4MultimodalAudioPreTrainedModel):
    def __init__(self, config: Phi4MultimodalAudioConfig):
        super().__init__(config)
        self.config = config

        self.encoder_embedding = Phi4MultimodalAudioMeanVarianceNormLayer(config)
        self.embed = Phi4MultimodalAudioNemoConvSubsampling(config)
        self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias(config)
        self.encoders = nn.ModuleList(
            [Phi4MultimodalAudioConformerEncoderLayer(config) for _ in range(config.num_blocks)]
        )
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
        # Create mask matrix for streaming
        # chunk_start_idx stores the start index of each chunk. If chunk_size is 18, it is [0, 18, 36, ...]
        chunk_start_idx = np.arange(0, seq_len, chunk_size)
        # avoid randomness when running evaluation or decoding
        if self.training and np.random.rand() > 0.5:
            # Either the first or the last chunk is not complete.
            # If only the last one is not complete, EOS is not effective
            chunk_start_idx = seq_len - chunk_start_idx
            chunk_start_idx = chunk_start_idx[::-1]
            chunk_start_idx = chunk_start_idx[:-1]
            chunk_start_idx = np.insert(chunk_start_idx, 0, 0)

        enc_streaming_mask = (
            adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk)
            .unsqueeze(0)
            .expand([batch_size, -1, -1])
        )
        return enc_streaming_mask

    def forward_embeddings(self, hidden_states, masks):
        """Forwarding the inputs through the top embedding layers"""
        seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction)
        if seq_len <= 0:
            raise ValueError(
                f"The sequence length after time reduction is invalid: {seq_len}. Your input feature is too short."
) batch_size = hidden_states.shape[0] enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.config.chunk_size, self.config.left_chunk) enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) hidden_states, masks = self.embed(hidden_states, masks) streaming_mask = enc_streaming_mask if streaming_mask is not None and masks is not None: hs_mask = masks & streaming_mask elif masks is not None: hs_mask = masks else: hs_mask = streaming_mask return hidden_states, hs_mask, masks def calculate_hs_mask(self, hidden_states, device, mask): max_audio_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] enc_streaming_mask = self._streaming_mask( max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk ) enc_streaming_mask = enc_streaming_mask.to(device) if mask is None: return enc_streaming_mask feature_lens = mask.sum(1) padding_length = feature_lens pad_mask = torch.arange(0, max_audio_length, device=device).expand( padding_length.size(0), -1 ) < padding_length.unsqueeze(1) pad_mask = pad_mask.unsqueeze(1) pad_mask = pad_mask & enc_streaming_mask return pad_mask def forward(self, hidden_states: torch.Tensor, mask: Optional[torch.Tensor]): hidden_states = self.encoder_embedding(hidden_states) hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) unfolded = False bs, seq_len, _ = hidden_states.shape max_seq_len = 500 # maxium position for absolute positional encoding if seq_len > max_seq_len: # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len unfolded = True # the unfold op will drop residual frames, pad it to the multiple of max_seq_len if seq_len % max_seq_len > 0: chunk_pad_size = max_seq_len - (seq_len % max_seq_len) else: chunk_pad_size = 0 if chunk_pad_size > 0: hidden_states_pad = F.pad(hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0) hidden_states = hidden_states_pad.to(hidden_states.device) hidden_states = unfold_tensor(hidden_states, max_seq_len) masks_unfold = None if mask is not None: # revise hs_mask here because the previous calculated hs_mask did not consider extra pad subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] extra_padded_subsamlped_pad_mask = F.pad( subsampled_pad_mask, (0, chunk_pad_size), "constant", False ) # extra padding to the pad mask extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() masks_unfold = unfold_tensor( extra_padded_subsamlped_pad_mask, max_seq_len ) # unfold the pad mask like we did to the input tensor masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor hs_mask = self.calculate_hs_mask( hidden_states, hidden_states.device, masks_unfold ) # calculate hs_mask based on the unfolded pad mask relative_attention_bias = self.relative_attention_bias_layer(hidden_states) attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias for layer in self.encoders: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( layer.__call__, hidden_states, attention_mask, ) else: hidden_states = layer(hidden_states, attention_mask) if unfolded: embed_dim = hidden_states.shape[-1] hidden_states = hidden_states.reshape(bs, -1, embed_dim) # if we ever padded before unfolding, we need to remove the padding if chunk_pad_size > 0: hidden_states = hidden_states[:, :-chunk_pad_size, :] return hidden_states class Phi4MultimodalAudioEmbedding(nn.Module): def __init__(self, config: Phi4MultimodalConfig): super().__init__() 
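        # Audio features are encoded by a shared conformer encoder and then projected into the language model's
        # hidden space through one of two MLP heads: `*_for_speech` for audio-only inputs and
        # `*_for_vision_speech` for mixed image + audio inputs (selected via `audio_projection_mode` in `forward`).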
self.config = config self.layer_idx = config.audio_config.feature_layer self.drop = nn.Dropout(config.embd_pdrop) self.encoder = Phi4MultimodalAudioModel._from_config(config.audio_config) self.up_proj_for_speech = nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) self.down_proj_for_speech = nn.Linear(config.hidden_size, config.hidden_size) self.up_proj_for_vision_speech = nn.Linear( config.audio_config.hidden_size * config.audio_config.downsample_rate, config.hidden_size ) self.down_proj_for_vision_speech = nn.Linear(config.hidden_size, config.hidden_size) def forward( self, input_ids: torch.LongTensor, inputs_embeds: torch.Tensor, audio_input_features: torch.FloatTensor, audio_embed_sizes=None, audio_attention_mask=None, audio_projection_mode="speech", ) -> torch.FloatTensor: with torch.no_grad(): positions_tuple = torch.nonzero(input_ids == self.config.audio_config.audio_token_id, as_tuple=True) up_proj = self.up_proj_for_speech if audio_projection_mode == "speech" else self.up_proj_for_vision_speech down_proj = ( self.down_proj_for_speech if audio_projection_mode == "speech" else self.down_proj_for_vision_speech ) target_device = up_proj.bias.device target_dtype = up_proj.bias.dtype audio_input_features = audio_input_features.to(device=target_device, dtype=target_dtype) audio_encoder_hidden_states = self.encoder(audio_input_features, audio_attention_mask) audio_encoder_hidden_states = up_proj(audio_encoder_hidden_states) audio_encoder_hidden_states = nn.functional.gelu(audio_encoder_hidden_states) audio_embeds = down_proj(audio_encoder_hidden_states) merged_audio_embeds = torch.cat( [audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 ) merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): audio_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_audio_embeds, accumulate=False ) audio_embeds = self.drop(audio_embeds) return audio_embeds class Phi4MultimodalRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ Phi4MultimodalRMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" class Phi4MultimodalMLP(nn.Module): def __init__(self, config): super().__init__() self.config = config self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) self.activation_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: up_states = self.gate_up_proj(hidden_states) gate, up_states = up_states.chunk(2, dim=-1) up_states = up_states * self.activation_fn(gate) return self.down_proj(up_states) def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 
2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
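
    Example (illustrative sketch; the shapes below are assumptions, with `rotary_dim = cos.shape[-1] <= head_dim`):

        # q, k: (batch, num_heads, seq_len, head_dim); cos, sin: (batch, seq_len, rotary_dim)
        q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
        # only the first `rotary_dim` channels of q and k are rotated; the remaining channels pass through unchanged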
""" cos = cos.unsqueeze(unsqueeze_dim) sin = sin.unsqueeze(unsqueeze_dim) rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) return q_embed, k_embed class Phi4MultimodalAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: Phi4MultimodalConfig, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads self.num_key_value_heads = config.num_key_value_heads self.scaling = self.head_dim**-0.5 self.attention_dropout = config.attention_dropout self.is_causal = True op_size = config.num_attention_heads * self.head_dim + 2 * (config.num_key_value_heads * self.head_dim) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) def forward( self, hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) qkv = self.qkv_proj(hidden_states) query_pos = self.config.num_attention_heads * self.head_dim query_states = qkv[..., :query_pos] key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] query_states = query_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): logger.warning_once( "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=getattr(self.config, "sliding_window", None),
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Phi4MultimodalDecoderLayer(nn.Module):
    def __init__(self, config: Phi4MultimodalConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Phi4MultimodalAttention(config=config, layer_idx=layer_idx)
        self.mlp = Phi4MultimodalMLP(config)
        self.input_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.config = config
        self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
        self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
                Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
                `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
            past_key_value (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code into the model
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + self.resid_attn_dropout(hidden_states)  # main diff with Llama

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.resid_mlp_dropout(hidden_states)  # main diff with Llama

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs


class Phi4MultimodalFeatureEmbedding(nn.Module):
    """Image-audio embedding."""

    def __init__(self, config: Phi4MultimodalConfig) -> None:
        super().__init__()
        self.config = config
        self.image_token_id = config.vision_config.image_token_id
        self.audio_token_id = config.audio_config.audio_token_id
        self.image_embed = Phi4MultimodalImageEmbedding(config)
        self.audio_embed = Phi4MultimodalAudioEmbedding(config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        image_sizes=None,
        image_attention_mask=None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
    ) -> torch.FloatTensor:
        with torch.no_grad():
            image_position_mask = (input_ids == self.config.vision_config.image_token_id).unsqueeze(-1)
            non_image_position_mask = ~image_position_mask

        image_embeds = None
        audio_embeds = None
        if image_pixel_values is not None and (input_ids == self.image_token_id).any():
            image_embeds = self.image_embed(
                input_ids,
                inputs_embeds,
                image_pixel_values=image_pixel_values,
                image_sizes=image_sizes,
                image_attention_mask=image_attention_mask,
            )
        if audio_input_features is not None and (input_ids == self.audio_token_id).any():
            audio_projection_mode = "vision" if image_pixel_values is not None else "speech"
            audio_embeds = self.audio_embed(
                input_ids,
                inputs_embeds,
                audio_input_features=audio_input_features,
                audio_embed_sizes=audio_embed_sizes,
                audio_attention_mask=audio_attention_mask,
                audio_projection_mode=audio_projection_mode,
            )

        # merge image and audio
        if image_embeds is not None and audio_embeds is not None:
            inputs_embeds = image_embeds * image_position_mask + audio_embeds * non_image_position_mask
        elif image_embeds is not None:
            inputs_embeds = image_embeds
        elif audio_embeds is not None:
            inputs_embeds = audio_embeds

        return inputs_embeds


class Phi4MultimodalRotaryEmbedding(nn.Module):
    def __init__(self, config: Phi4MultimodalConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
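        # For the "default" rope_type, the init function returns the standard RoPE inverse frequencies,
        # inv_freq[i] = 1.0 / rope_theta ** (2 * i / rotary_dim), with attention_scaling == 1.0; scaled variants
        # (e.g. longrope) rescale these frequencies and may return a non-trivial attention_scaling factor.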
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


PHI4_MULTIMODAL_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`Phi4MultimodalConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
    PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalPreTrainedModel(PreTrainedModel):
    config_class = Phi4MultimodalConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Phi4MultimodalDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True
    _version = "0.0.5"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Phi4MultimodalRMSNorm):
            module.weight.data.fill_(1.0)
        elif isinstance(module, Phi4MultimodalImageEmbedding):
            module.global_img_feature_extensor.data.zero_()
            module.sub_img_feature_extensor.data.zero_()


PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices in `input_ids`.
            Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
            `[0, config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks)
            that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned
            by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. See our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        image_pixel_values (`torch.FloatTensor`, *optional*):
            If the input contains images, these correspond to the pixel values after transformations (as returned by
            the Processor).
        image_sizes (`torch.LongTensor`, *optional*):
            If the input contains images, these correspond to the size of each image.
        image_attention_mask (`torch.LongTensor`, *optional*):
            Attention mask for the images.
        audio_input_features (`torch.FloatTensor`, *optional*):
            If the input contains audio samples, these correspond to the values after transformation (as returned by
            the Processor).
        audio_embed_sizes (`torch.Tensor`, *optional*):
            Size of the audio inputs.
        audio_attention_mask (`torch.Tensor`, *optional*):
            Attention mask for the audio inputs.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare Phi4Multimodal Model outputting raw hidden-states without any specific head on top.",
    PHI4_MULTIMODAL_START_DOCSTRING,
)
class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers.
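    Image and audio inputs, when present, are merged into the token embeddings by [`Phi4MultimodalFeatureEmbedding`]
    before the first decoder layer.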
    Each layer is a [`Phi4MultimodalDecoderLayer`].

    Args:
        config: Phi4MultimodalConfig
    """

    def __init__(self, config: Phi4MultimodalConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Phi4MultimodalDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Phi4MultimodalRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Phi4MultimodalRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        self.embed_tokens_extend = Phi4MultimodalFeatureEmbedding(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @can_return_tuple
    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_attention_mask=None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        inputs_embeds = self.embed_tokens_extend(
            input_ids,
            inputs_embeds,
            image_pixel_values=image_pixel_values,
            audio_input_features=audio_input_features,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            audio_embed_sizes=audio_embed_sizes,
            audio_attention_mask=audio_attention_mask,
        )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and past_key_values is not None:
                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
                if is_padding_right:
                    raise ValueError(
                        "You are attempting to perform batched generation with padding_side='right'"
                        " this may lead to unexpected behaviour for Flash Attention version of Phi4Multimodal. Make sure to "
                        " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                    )
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
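        # From this point on, the mask that gets returned (if any) is an additive float mask: positions that may be
        # attended to contribute 0.0 and masked positions contribute torch.finfo(dtype).min, so they effectively
        # vanish after the softmax in the eager and SDPA attention paths.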
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Phi4MultimodalConfig,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`Phi4MultimodalConfig`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.get_text_config().sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify if the current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.get_text_config().sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class Phi4MultimodalForCausalLM(Phi4MultimodalPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

    def __init__(self, config):
        super().__init__(config)
        self.model = Phi4MultimodalModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @can_return_tuple
    @add_start_docstrings_to_model_forward(PHI4_MULTIMODAL_MODEL_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=Phi4MultimodalConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_pixel_values: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.LongTensor] = None,
        image_attention_mask=None,
        audio_input_features: Optional[torch.FloatTensor] = None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for
            that token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Phi4MultimodalForCausalLM

        >>> model = Phi4MultimodalForCausalLM.from_pretrained("TBA")
        >>> tokenizer = AutoTokenizer.from_pretrained("TBA")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            image_pixel_values=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            audio_input_features=audio_input_features,
            audio_embed_sizes=audio_embed_sizes,
            audio_attention_mask=audio_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        image_pixel_values=None,
        image_sizes=None,
        image_attention_mask=None,
        audio_input_features=None,
        audio_embed_sizes=None,
        audio_attention_mask=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        logits_to_keep=0,
        **kwargs,
    ):
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process.
        # When the input length first reaches the long/short factor switching point, enforce re-computing the cache.
        # This causes the downside of being slower at this single token position, but it is better than the current
        # failure mode.
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            image_pixel_values=image_pixel_values,
            image_sizes=image_sizes,
            image_attention_mask=image_attention_mask,
            audio_input_features=audio_input_features,
            audio_embed_sizes=audio_embed_sizes,
            audio_attention_mask=audio_attention_mask,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )
        return model_inputs


__all__ = [
    "Phi4MultimodalAudioPreTrainedModel",
    "Phi4MultimodalAudioModel",
    "Phi4MultimodalVisionPreTrainedModel",
    "Phi4MultimodalVisionModel",
    "Phi4MultimodalPreTrainedModel",
    "Phi4MultimodalModel",
    "Phi4MultimodalForCausalLM",
]

Phi4MultimodalForCausalLM.register_for_auto_class("AutoModelForCausalLM")
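

# Illustrative usage sketch (kept as a comment so importing this module has no side effects). The checkpoint name,
# chat-format tokens and `<|image_1|>` placeholder below are assumptions taken from the public Phi-4-multimodal
# model card; adapt them to the checkpoint and prompt format you actually use.
#
#     from transformers import AutoProcessor
#     from PIL import Image
#
#     checkpoint = "microsoft/Phi-4-multimodal-instruct"  # assumed checkpoint name
#     processor = AutoProcessor.from_pretrained(checkpoint)
#     model = Phi4MultimodalForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
#
#     prompt = "<|user|><|image_1|>Describe this image.<|end|><|assistant|>"
#     inputs = processor(text=prompt, images=Image.open("example.png"), return_tensors="pt")
#     generated = model.generate(**inputs, max_new_tokens=64)
#     print(processor.batch_decode(generated[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0])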