voidful committed
Commit e9f54ae · verified · 1 Parent(s): f0fe779

Upload 12 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
File without changes
chat_template.jinja ADDED
@@ -0,0 +1 @@
+ {{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}{% set loop_messages = messages[1:] %}{% else %}{% set first_user_prefix = '' %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else '') }}{% if role == 'model' and message.get('metadata') %}{% if message['metadata']['type'] == 'think' %}<think>{% if message['metadata'].get('range') %}<range>{{ message['metadata']['range'] }}</range>{% endif %}{% if message['metadata'].get('content') %}{{ message['metadata']['content'] | trim }}{% endif %}</think>{% elif message['metadata']['type'] == 'direct' %}<direct>{% endif %}{% if message['metadata'].get('function') %}<function>{{ message['metadata']['function'] | join(',') }}</function>{% endif %}{% endif %}{% if message['content'] is string %}{{ message['content'] | trim }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{{ '<start_of_image>' if item['type']=='image' else '<start_of_audio>' if item['type']=='audio' else item['text'] | trim if item['type']=='text' else '' }}{% endfor %}{% else %}{{ raise_exception('Invalid content type') }}{% endif %}{{ '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}
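A minimal sketch of how this template renders, assuming the checkpoint directory ships the files from this commit; "<local_checkpoint_dir>" is a placeholder, not a confirmed repo id:

# Hedged sketch: render the chat template above through a tokenizer that carries it.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<local_checkpoint_dir>")
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "audio"}, {"type": "text", "text": "Transcribe this clip."}]},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Per the template: the system text is folded into the first user turn as a prefix,
# audio items render as <start_of_audio>, and add_generation_prompt appends "<start_of_turn>model\n".
print(prompt)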
config.json CHANGED
@@ -2,7 +2,13 @@
   "architectures": [
     "Gemma3OmniForConditionalGeneration"
   ],
-  "audio_token_index": 262151,
+  "auto_map": {
+    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
+    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor",
+    "AutoModel": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
+    "AutoModelForCausalLM": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
+    "AutoConfig": "configuration_gemma3_omni.Gemma3OmniConfig"
+  },
   "boi_token_index": 255999,
   "eoi_token_index": 256000,
   "eos_token_id": [
@@ -122,4 +128,4 @@
   "torch_dtype": "float32",
   "vision_use_head": false
   }
-}
+}
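A hedged sketch of what the new auto_map enables: loading the custom classes via the Auto* APIs with trust_remote_code. "<repo_or_local_path>" is a placeholder for this repository or a local clone of it.

from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

path = "<repo_or_local_path>"  # placeholder, not a confirmed repo id
config = AutoConfig.from_pretrained(path, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
print(type(config).__name__, type(model).__name__)
# Expected (per auto_map): Gemma3OmniConfig, Gemma3OmniForConditionalGeneration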
configuration_gemma3_omni.py ADDED
@@ -0,0 +1,55 @@
+from typing import Optional, Union, Dict, Any
+
+from transformers import Gemma3TextConfig, SiglipVisionConfig, PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class Gemma3OmniConfig(PretrainedConfig):
+    model_type = "gemma3omni"
+    attribute_map = {
+        "image_token_id": "image_token_index",
+        "audio_token_id": "audio_token_index",
+        "boi_token_id": "boi_token_index",
+        "eoi_token_id": "eoi_token_index",
+    }
+    sub_configs = {
+        "text_config": Gemma3TextConfig,
+        "vision_config": SiglipVisionConfig,
+    }
+
+    def __init__(
+        self,
+        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
+        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
+        mm_tokens_per_image: int = 256,
+        boi_token_index: int = 255_999,
+        eoi_token_index: int = 256_000,
+        image_token_index: int = 262_144,
+        audio_token_index: int = 262_151,
+        initializer_range: float = 0.02,
+        **kwargs,
+    ):
+        if text_config is None:
+            text_config = Gemma3TextConfig()
+            logger.info("text_config is None, using default Gemma3TextConfig text config.")
+        elif isinstance(text_config, dict):
+            text_config = Gemma3TextConfig(**text_config)
+
+        if isinstance(vision_config, dict):
+            vision_config = SiglipVisionConfig(**vision_config)
+        elif vision_config is None:
+            vision_config = SiglipVisionConfig()
+            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")
+
+        self.text_config = text_config
+        self.vision_config = vision_config
+        self.mm_tokens_per_image = mm_tokens_per_image
+        self.boi_token_index = boi_token_index
+        self.eoi_token_index = eoi_token_index
+        self.image_token_index = image_token_index
+        self.audio_token_index = audio_token_index
+        self.initializer_range = initializer_range
+
+        super().__init__(**kwargs)
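A hedged usage sketch of the config class above, assuming the file is importable from the working directory; the sub-config values shown are illustrative, not the shipped ones.

from configuration_gemma3_omni import Gemma3OmniConfig

cfg = Gemma3OmniConfig(
    text_config={"hidden_size": 1152, "num_hidden_layers": 26},  # example values only
    vision_config=None,  # falls back to a default SiglipVisionConfig()
    audio_token_index=262_151,
)
# attribute_map exposes *_token_id aliases for the *_token_index fields
print(cfg.audio_token_id, cfg.mm_tokens_per_image)  # 262151 256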
modeling_gemma3_omni.py ADDED
@@ -0,0 +1,461 @@
+# -*- coding: utf-8 -*-
+from __future__ import annotations
+
+from typing import List, Optional, Tuple, Union, Callable
+
+import torch
+from torch import nn
+
+from transformers import (
+    AutoModel,
+    Cache,
+    PreTrainedModel,
+    PretrainedConfig,
+)
+from transformers.generation import GenerationMixin
+from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
+from transformers.models.gemma3.modeling_gemma3 import (
+    Gemma3CausalLMOutputWithPast,
+    Gemma3RMSNorm, Gemma3PreTrainedModel, Gemma3ModelOutputWithPast,
+)
+from transformers.utils import is_torchdynamo_compiling, logging, is_torch_flex_attn_available
+
+try:
+    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+except ImportError:
+    LigerFusedLinearCrossEntropyLoss = None
+
+from .configuration_gemma3_omni import Gemma3OmniConfig
+from .speech_conformer_encoder import ConformerEncoder
+
+logger = logging.get_logger(__name__)
+
+if is_torch_flex_attn_available():
+    from torch.nn.attention.flex_attention import BlockMask
+
+
+class Gemma3AudioProjectorConfig(PretrainedConfig):
+    model_type = "gemma3_audio"
+
+    def __init__(
+        self,
+        hidden_size: int = 1024,
+        num_hidden_layers: int = 24,
+        sample_rate: int = 16_000,
+        n_mels: int = 80,
+        image_token_index: int = 0,  # Unused by the audio projector; likely carried over from the vision config
+        # Mel-spectrogram specific parameters
+        n_fft: int = 400,  # Typical for a 25 ms window at 16 kHz
+        hop_length: int = 160,  # Typical for a 10 ms hop at 16 kHz
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.sample_rate = sample_rate
+        self.n_mels = n_mels
+        self.image_token_index = image_token_index
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+
+
+class LayerWiseWeightedSum(nn.Module):
+    def __init__(self, num_layers: int, learnable: bool = True):
+        super().__init__()
+        self.num_layers = num_layers
+        if learnable:
+            self.scalar = nn.Parameter(torch.zeros(num_layers))
+        else:
+            self.register_buffer("scalar", torch.zeros(num_layers))
+
+    def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
+        assert len(hidden_states) == self.num_layers
+        norm_w = torch.softmax(self.scalar, dim=0).view(-1, 1, 1, 1)
+        stacked = torch.stack(hidden_states, dim=0)
+        return (norm_w * stacked).sum(dim=0)
+
+
+class Gemma3AudioProjector(PreTrainedModel):
+    """Conformer-based audio encoder projected to the LM hidden dimension."""
+
+    config_class = Gemma3AudioProjectorConfig
+    base_model_prefix = "audio_projector"
+
+    def __init__(self, config: Gemma3AudioProjectorConfig):
+        super().__init__(config)
+        encoder_config = {
+            "activation": "swish",
+            "activation_checkpointing": "",
+            "attention_dim": 1024,
+            "attention_heads": 16,
+            "batch_norm": False,
+            "bias_in_glu": True,
+            "causal": True,
+            "chunk_size": -1,
+            "conv_activation": "swish",
+            "conv_glu_type": "swish",
+            "depthwise_multiplier": 1,
+            "depthwise_seperable_out_channel": 1024,
+            "dropout_rate": 0.0,
+            "encoder_embedding_config": {
+                "input_size": config.n_mels  # feat_in for NemoConvSubsampling
+            },
+            "ext_pw_kernel_size": 1,
+            "ext_pw_out_channel": 1024,
+            "input_layer": "nemo_conv",
+            "input_size": config.n_mels,  # Also feat_in for NemoConvSubsampling, kept consistent
+            "kernel_size": 3,
+            "left_chunk": 18,
+            "linear_units": 1536,
+            "nemo_conv_settings": {
+                "conv_channels": 1024,
+            },
+            "num_blocks": 24,
+            "relative_attention_bias_args": {
+                "t5_bias_max_distance": 500,
+                "type": "t5"
+            },
+            "time_reduction": 8
+        }
+        self.encoder = ConformerEncoder(**encoder_config)
+        self.layer_weighter = LayerWiseWeightedSum(
+            num_layers=encoder_config["num_blocks"]
+        )
+        self.proj = nn.Linear(encoder_config['attention_dim'], config.hidden_size, bias=False)
+
+    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor):
+        mel = mel.squeeze(1)  # (B, T, 80)
+        mel_mask = mel_mask.squeeze(1)  # (B, L)
+
+        if mel_mask.size(1) != mel.size(1):
+            mel_mask = mel_mask[..., : mel.size(1)]
+
+        _, out_mask, hidden_list = self.encoder(mel, mel_mask)
+        hidden_sum = self.layer_weighter(hidden_list)
+        hidden = self.proj(hidden_sum)
+        return hidden, out_mask
+
+
+class Gemma3VisionProjector(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
+        )
+        self.mm_soft_emb_norm = Gemma3RMSNorm(
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
+        self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
+        self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
+
+    def forward(self, vision_outputs: torch.Tensor):
+        b, _, seq_len = vision_outputs.shape
+        x = vision_outputs.transpose(1, 2).reshape(
+            b, seq_len, self.patches_per_image, self.patches_per_image
+        )
+        x = self.avg_pool(x).flatten(2).transpose(1, 2)
+        x = self.mm_soft_emb_norm(x)
+        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)
+
+
+def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
+    """
+    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
+    not start and end indices.
+    """
+    # Do not return an additional mask in this case
+    if token_type_ids is None:
+        return None
+
+    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+        # If it's 1, we need to unmask it
+        return token_type_ids[batch_idx, kv_idx] == 1
+
+    return inner_mask
+
+
+class Gemma3OmniModel(Gemma3PreTrainedModel):
+    config_class = Gemma3OmniConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        self.multi_modal_projector = Gemma3VisionProjector(config)
+        self.audio_projector = Gemma3AudioProjector(
+            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
+        )
+        self.vocab_size = config.text_config.vocab_size
+
+        language_model = AutoModel.from_config(config=config.text_config)
+        self.language_model = language_model
+
+        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.language_model.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        input_audio_embeds: Optional[torch.FloatTensor] = None,
+        audio_attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **lm_kwargs,
+    ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # Replace the image token id with PAD if it is OOV, to avoid index errors
+        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_id
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        if inputs_embeds is None:
+            inputs_embeds = self.get_input_embeddings()(llm_input_ids)
+
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        image_features = None  # only set when vision features are computed in this call
+        if pixel_values is not None and past_key_values is None:
+            image_features = self.get_image_features(pixel_values)
+
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+                )
+            else:
+                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
+                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
+                raise ValueError(
+                    f"Number of images does not match number of special image tokens in the input text. "
+                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
+                    "tokens from image embeddings."
+                )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        if input_audio_embeds is not None and past_key_values is None:
+            audio_features, audio_feat_mask = self.audio_projector(
+                input_audio_embeds, audio_attention_mask
+            )
+            if input_ids is None:
+                special_audio_mask = (
+                    inputs_embeds
+                    == self.get_input_embeddings()(
+                        torch.tensor(
+                            self.config.audio_token_index,
+                            dtype=torch.long,
+                            device=inputs_embeds.device,
+                        )
+                    )
+                )
+            else:
+                special_audio_mask = (
+                    input_ids == self.config.audio_token_index
+                ).unsqueeze(-1)  # [B, L, 1]
+                special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
+                    inputs_embeds.device
+                )
+            if (
+                not is_torchdynamo_compiling()
+                and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
+            ):
+                audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0]
+                raise ValueError(
+                    f"Number of audio tokens in the text ({audio_tokens_in_text}) "
+                    f"≠ number of tokens from audio embeddings "
+                    f"({audio_features.shape[0] * audio_features.shape[1]})."
+                )
+            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            audio_features = audio_features.reshape(-1)
+            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)
+
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config.get_text_config(),
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+            }
+            if token_type_ids is not None and inputs_embeds.shape[1] != 1:
+                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
+                    token_type_ids.to(cache_position.device)
+                )
+
+            # Create the masks
+            causal_mask_mapping = {
+                "full_attention": create_causal_mask(**mask_kwargs),
+                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
+            }
+
+        outputs = self.language_model(
+            attention_mask=causal_mask_mapping,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cache_position=cache_position,
+            **lm_kwargs,
+        )
+
+        return Gemma3ModelOutputWithPast(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values if use_cache else None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=image_features,
+        )
+
+
+class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
+    """Gemma-3 Omni: vision + audio + text causal LM."""
+
+    config_class = Gemma3OmniConfig
+    _checkpoint_conversion_mapping = {
+        "^language_model.model": "model.language_model",
+        "^vision_tower": "model.vision_tower",
+        "^multi_modal_projector": "model.multi_modal_projector",
+        "^language_model.lm_head": "lm_head",
+    }
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = Gemma3OmniModel(config)
+        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.language_model.embed_tokens
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        pixel_values: torch.FloatTensor = None,
+        input_audio_embeds: Optional[torch.FloatTensor] = None,
+        audio_attention_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **lm_kwargs,
+    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            input_audio_embeds=input_audio_embeds,
+            audio_attention_mask=audio_attention_mask,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            labels=labels,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **lm_kwargs,
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        loss = None
+        if labels is not None:
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
+            if attention_mask is not None:
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
+            else:
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
+
+            # Flatten the tokens
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+
+            if LigerFusedLinearCrossEntropyLoss is not None:
+                loss_fct = LigerFusedLinearCrossEntropyLoss()
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(flat_logits, flat_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return Gemma3CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            image_hidden_states=outputs.image_hidden_states,
+        )
+
+
+__all__ = [
+    "Gemma3AudioProjectorConfig",
+    "Gemma3AudioProjector",
+    "Gemma3VisionProjector",
+    "Gemma3OmniModel",
+    "Gemma3OmniForConditionalGeneration",
+]
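A hedged worked example of the Gemma3VisionProjector pooling arithmetic above. The SigLIP numbers (image_size=896, patch_size=14) and mm_tokens_per_image=256 are assumptions taken from the accompanying configs, for illustration only.

image_size, patch_size, mm_tokens_per_image = 896, 14, 256

patches_per_image = image_size // patch_size        # 64 patches per side
tokens_per_side = int(mm_tokens_per_image ** 0.5)   # 16
kernel_size = patches_per_image // tokens_per_side  # 4x4 average pooling window
print(patches_per_image, tokens_per_side, kernel_size)  # 64 16 4
# A 64x64 patch grid is average-pooled to 16x16, giving 256 soft image tokens per image.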
preprocessor_config.json ADDED
@@ -0,0 +1,45 @@
+{
+  "auto_map": {
+    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
+    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor"
+  },
+  "do_convert_rgb": null,
+  "do_normalize": true,
+  "do_pan_and_scan": null,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
+    0.5,
+    0.5,
+    0.5
+  ],
+  "image_processor_type": "Gemma3ImageProcessor",
+  "image_seq_length": 256,
+  "image_std": [
+    0.5,
+    0.5,
+    0.5
+  ],
+  "pan_and_scan_max_num_crops": null,
+  "pan_and_scan_min_crop_size": null,
+  "pan_and_scan_min_ratio_to_activate": null,
+  "resample": 2,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "height": 896,
+    "width": 896
+  },
+  "compression_rate": 4,
+  "feat_stride": 4,
+  "feature_extractor_type": "Gemma3AudioFeatureExtractor",
+  "feature_size": 80,
+  "hop_length": 160,
+  "n_fft": 512,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "processor_class": "Gemma3OmniProcessor",
+  "qformer_rate": 2,
+  "return_attention_mask": true,
+  "sampling_rate": 16000,
+  "win_length": 400
+}
processing_gemma3_omni.py ADDED
@@ -0,0 +1,494 @@
+from typing import List, Optional, Union, Dict, Any, Tuple
+
+import numpy as np
+import scipy.signal
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import make_nested_list_of_images
+from transformers.processing_utils import ProcessorMixin
+from transformers.utils import TensorType, logging
+
+DEFAULT_SPECIAL_TOKENS = {
+    "bos_token": "<bos>",
+    "eos_token": "<eos>",
+    "pad_token": "<pad>",
+    "unk_token": "<unk>",
+    "boi_token": "<start_of_image>",
+    "eoi_token": "<end_of_image>",
+    "image_token": "<image_soft_token>",
+    "boa_token": "<start_of_audio>",
+    "eoa_token": "<end_of_audio>",
+    "audio_token": "<audio_soft_token>",
+}
+DEFAULT_SAMPLING_RATE = 16000
+DEFAULT_N_FFT = 512
+DEFAULT_WIN_LENGTH = 400
+DEFAULT_HOP_LENGTH = 160
+DEFAULT_N_MELS = 80
+DEFAULT_COMPRESSION_RATE = 4
+DEFAULT_QFORMER_RATE = 8
+DEFAULT_FEAT_STRIDE = 4
+DEFAULT_IMAGE_SEQ_LENGTH = 256
+DEFAULT_MAX_LENGTH = 16384
+
+logger = logging.get_logger(__name__)
+
+
+def compute_audio_token_count(
+    mel_frame_count: int,
+    *,
+    feat_stride: int = DEFAULT_FEAT_STRIDE,
+    compression_rate: int = DEFAULT_COMPRESSION_RATE,
+    qformer_rate: int = DEFAULT_QFORMER_RATE,
+) -> int:
+    audio_frames = mel_frame_count * feat_stride
+    audio_frames = (audio_frames + compression_rate - 1) // compression_rate
+    audio_frames = (audio_frames + qformer_rate - 1) // qformer_rate
+    return audio_frames
+
+
+def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
+    bank_width = int(n_fft // 2 + 1)
+    if fmax is None:
+        fmax = sample_rate / 2
+    if fmin is None:
+        fmin = 0
+
+    def mel(f):
+        return 1127.0 * np.log(1.0 + f / 700.0)
+
+    def bin2mel(fft_bin):
+        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))
+
+    def f2bin(f):
+        return int((f * n_fft / sample_rate) + 0.5)
+
+    klo = f2bin(fmin) + 1
+    khi = f2bin(fmax)
+    khi = max(khi, klo)
+    mlo = mel(fmin)
+    mhi = mel(fmax)
+    m_centers = np.linspace(mlo, mhi, n_mels + 2)
+    ms = (mhi - mlo) / (n_mels + 1)
+    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
+    for m in range(n_mels):
+        left = m_centers[m]
+        center = m_centers[m + 1]
+        right = m_centers[m + 2]
+        for fft_bin in range(klo, khi):
+            mbin = bin2mel(fft_bin)
+            if left < mbin < right:
+                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
+    return matrix
+
+
+class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
+    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]
+
+    def __init__(
+        self,
+        audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
+        audio_downsample_rate: int = DEFAULT_QFORMER_RATE,
+        audio_feat_stride: int = DEFAULT_FEAT_STRIDE,
+        feature_size: int = DEFAULT_N_MELS,
+        sampling_rate: int = DEFAULT_SAMPLING_RATE,
+        padding_value: float = 0.0,
+        eightk_method: str = "fillzero",
+        **kwargs,
+    ):
+        super().__init__(
+            feature_size=kwargs.pop("feature_size", feature_size),
+            sampling_rate=kwargs.pop("sampling_rate", sampling_rate),
+            padding_value=kwargs.pop("padding_value", padding_value),
+            **kwargs,
+        )
+        self.compression_rate = audio_compression_rate
+        self.qformer_compression_rate = audio_downsample_rate
+        self.feat_stride = audio_feat_stride
+        self._eightk_method = eightk_method
+        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
+        self._hamming400 = np.hamming(400)
+        self._hamming200 = np.hamming(200)
+
+    def __call__(
+        self,
+        audios: List[Tuple[np.ndarray, int]],
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ):
+        returned_input_audio_embeds = []
+        returned_audio_embed_sizes = []
+        audio_frames_list = []
+        for audio_data, sample_rate in audios:
+            if isinstance(audio_data, list):
+                audio_data = np.array(audio_data, dtype=np.float32)
+            if not isinstance(audio_data, np.ndarray):
+                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")
+            audio_embeds_np = self._extract_features(audio_data, sample_rate)
+            num_mel_frames = audio_embeds_np.shape[0]
+            current_audio_frames = num_mel_frames * self.feat_stride
+            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)
+            returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
+            returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
+            audio_frames_list.append(current_audio_frames)
+        padded_input_audio_embeds = pad_sequence(
+            returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
+        )
+        stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
+        tensor_audio_frames = torch.tensor(audio_frames_list, dtype=torch.long)
+        max_audio_frames = tensor_audio_frames.max().item() if tensor_audio_frames.numel() > 0 else 0
+        if max_audio_frames > 0 and len(audios) > 1:
+            audio_attention_mask = (
+                torch.arange(0, max_audio_frames, device=tensor_audio_frames.device).unsqueeze(0)
+                < tensor_audio_frames.unsqueeze(1)
+            )
+        elif max_audio_frames > 0:
+            audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool, device=tensor_audio_frames.device)
+        else:
+            audio_attention_mask = None
+        data = {
+            "input_audio_embeds": padded_input_audio_embeds,
+            "audio_embed_sizes": stacked_audio_embed_sizes,
+        }
+        if audio_attention_mask is not None:
+            data["audio_attention_mask"] = audio_attention_mask
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
+        if wav.ndim > 1:
+            wav = np.squeeze(wav)
+        if len(wav.shape) == 2:
+            wav = wav.mean(axis=1).astype(np.float32)
+        wav = wav.astype(np.float32)
+        current_fs = fs
+        if current_fs > self.sampling_rate:
+            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
+            current_fs = self.sampling_rate
+        elif 8000 < current_fs < self.sampling_rate:
+            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
+            current_fs = 8000
+        elif current_fs < 8000 and current_fs > 0:
+            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
+            current_fs = 8000
+        elif current_fs <= 0:
+            raise RuntimeError(f"Unsupported sample rate {current_fs}")
+        if current_fs == 8000 and self._eightk_method == "resample":
+            wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
+            current_fs = self.sampling_rate
+        elif current_fs != self.sampling_rate:
+            raise RuntimeError(
+                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods.")
+        preemphasis_coeff = 0.97
+        if current_fs == 8000:
+            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
+        else:
+            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
+        if len(wav) < win_length:
+            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))
+        num_frames = (wav.shape[0] - win_length) // hop_length + 1
+        if num_frames <= 0:
+            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)
+        y_frames = np.array(
+            [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
+            dtype=np.float32,
+        )
+        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
+        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
+        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0
+        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)
+        if current_fs == 8000 and self._eightk_method == "fillzero":
+            target_bins = (512 // 2) + 1
+            S_core = S[:, :-1]
+            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
+            S = np.concatenate((S_core, padarray), axis=1)
+        spec = np.abs(S).astype(np.float32)
+        return spec
+
+    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
+        spec = self._extract_spectrogram(wav, fs)
+        if spec.shape[0] == 0:
+            return np.zeros((0, self.feature_size), dtype=np.float32)
+        spec_power = spec ** 2
+        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
+        log_fbank = np.log(fbank_power).astype(np.float32)
+        return log_fbank
+
+    def _compute_audio_embed_size(self, audio_frames: int) -> int:
+        integer = audio_frames // self.compression_rate
+        remainder = audio_frames % self.compression_rate
+        result = integer if remainder == 0 else integer + 1
+        integer = result // self.qformer_compression_rate
+        remainder = result % self.qformer_compression_rate
+        result = integer if remainder == 0 else integer + 1
+        return result
+
+
+class Gemma3OmniProcessor(ProcessorMixin):
+    attributes = ["image_processor", "audio_processor", "tokenizer"]
+    image_processor_class = "AutoImageProcessor"
+    audio_processor_class = "AutoFeatureExtractor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        audio_processor=None,
+        tokenizer=None,
+        special_tokens: Optional[Dict[str, str]] = None,
+        image_seq_length: int = DEFAULT_IMAGE_SEQ_LENGTH,
+        prompt_audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
+        prompt_audio_qformer_rate: int = DEFAULT_QFORMER_RATE,
+        audio_placeholder_token: str = "<|audio_placeholder|>",
+        **kwargs,
+    ):
+        super().__init__(
+            image_processor=image_processor,
+            audio_processor=audio_processor,
+            tokenizer=tokenizer,
+            **kwargs,
+        )
+        self.special_tokens = dict(DEFAULT_SPECIAL_TOKENS)
+        if special_tokens is not None:
+            self.special_tokens.update(special_tokens)
+        if tokenizer is not None:
+            for key in self.special_tokens:
+                val = getattr(tokenizer, key, None)
+                if isinstance(val, str):
+                    self.special_tokens[key] = val
+        self.image_token = self.special_tokens["image_token"]
+        self.audio_token = self.special_tokens["audio_token"]
+        self.boi_token = self.special_tokens["boi_token"]
+        self.eoi_token = self.special_tokens["eoi_token"]
+        self.boa_token = self.special_tokens["boa_token"]
+        self.eoa_token = self.special_tokens["eoa_token"]
+        self.image_seq_length = image_seq_length
+        self.full_image_sequence = f"{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}"
+        self.prompt_audio_compression_rate = prompt_audio_compression_rate
+        self.prompt_audio_qformer_rate = prompt_audio_qformer_rate
+        self.audio_placeholder_token = audio_placeholder_token
+        if self.tokenizer is not None:
+            self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
+            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
+        else:
+            self.image_token_id = None
+            self.audio_token_id = None
+
+    def compute_audio_token_count(self, mel_frame_count: int) -> int:
+        stride = getattr(self.audio_processor, "feat_stride", DEFAULT_FEAT_STRIDE)
+        return compute_audio_token_count(
+            mel_frame_count,
+            feat_stride=stride,
+            compression_rate=self.prompt_audio_compression_rate,
+            qformer_rate=self.prompt_audio_qformer_rate,
+        )
+
+    def apply_chat_template(
+        self,
+        messages,
+        add_generation_prompt: bool = True,
+        tokenize: bool = False,
+        **kwargs
+    ) -> Union[str, Dict[str, Any]]:
+        prompt = ""
+        images: List[Any] = []
+        audios: List[Any] = []
+        if isinstance(messages, dict) and "messages" in messages:
+            if "audios" in messages:
+                audios = messages["audios"]
+            if "audio" in messages:
+                audios = [messages["audio"]]
+            if "images" in messages:
+                images = messages["images"]
+            if "image" in messages:
+                images = [messages["image"]]
+            messages = messages["messages"]
+
+        for msg in messages:
+            role = msg.get("role", "")
+            prompt += f"<start_of_turn>{role}\n"
+            contents = msg.get("content", [])
+            if not isinstance(contents, list):
+                contents = [contents]
+
+            for c in contents:
+                if isinstance(c, dict):
+                    ctype = c.get("type")
+                    if ctype == "image":
+                        idx = c.get("index")
+                        img_data = None
+                        if idx is not None and isinstance(idx, int):
+                            img_data = images[idx]
+                        elif "image" in c:
+                            img_data = c["image"]
+                        if img_data is None:
+                            logger.warning("No image data found for image content: %s", c)
+                        prompt += self.full_image_sequence
+                        continue
+
+                    if ctype == "audio":
+                        idx = c.get("index")
+                        aud_data = None
+                        sr = self.audio_processor.sampling_rate if self.audio_processor else DEFAULT_SAMPLING_RATE
+                        if idx is not None and isinstance(idx, int):
+                            aud_data = audios[idx]["array"]
+                            sr = audios[idx].get("sampling_rate", sr)
+                        elif "audio" in c:
+                            aud_data = c["audio"]
+                            sr = c.get("sampling_rate", sr)
+                        if aud_data is None:
+                            logger.warning("No audio data found for audio content: %s", c)
+
+                        n_audio_tokens = 0
+                        if self.audio_processor:
+                            features = self.audio_processor(audios=[(aud_data, sr)], return_tensors=None)
+                            mel_frame_count = features["input_audio_embeds"].shape[1]
+                            n_audio_tokens = self.compute_audio_token_count(mel_frame_count)
+                        prompt += (
+                            self.boa_token +
+                            (self.audio_token * n_audio_tokens) +
+                            self.eoa_token
+                        )
+                        continue
+
+                    if ctype == "text" and "text" in c:
+                        prompt += str(c["text"])
+                        continue
+                    continue
+
+                elif isinstance(c, str):
+                    prompt += c
+                    continue
+                else:
+                    logger.warning("Unknown content type in message: %s", c)
+                    continue
+
+            prompt += "<end_of_turn>\n"
+
+        if add_generation_prompt:
+            prompt += "<start_of_turn>model\n"
+
+        if tokenize and self.tokenizer is not None:
+            safe_kwargs = {}
+            allowed_keys = {"return_tensors", "padding", "truncation", "max_length", "add_special_tokens"}
+            for k, v in kwargs.items():
+                if k in allowed_keys:
+                    safe_kwargs[k] = v
+            return self.tokenizer(prompt, **safe_kwargs)
+
+        return prompt
+
+    def __call__(
+        self,
+        text: Optional[Union[str, List[str]]] = None,
+        images: Optional[Union[Any, List[Any]]] = None,
+        audios: Optional[Union[Tuple[np.ndarray, int], List[Tuple[np.ndarray, int]]]] = None,
+        messages: Optional[List[Dict]] = None,
+        add_generation_prompt: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = "pt",
+        device: Optional[str] = None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        if messages is not None:
+            if isinstance(messages, dict):
+                messages = [messages]
+            prompt = self.apply_chat_template(
+                messages,
+                add_generation_prompt=add_generation_prompt,
+                tokenize=False,
+            )
+            audio_inputs = []
+            for msg in messages:
+                contents = msg.get("content", [])
+                if not isinstance(contents, list):
+                    contents = [contents]
+                for c in contents:
+                    if isinstance(c, dict) and c.get("type") == "audio":
+                        arr = c["audio"]
+                        sr = c.get("sampling_rate",
+                                   self.audio_processor.sampling_rate if self.audio_processor else 16000)
+                        audio_inputs.append((arr, sr))
+            audio_features = {}
+            if audio_inputs and self.audio_processor is not None:
+                audio_features = self.audio_processor(audios=audio_inputs, return_tensors=return_tensors)
+            text_features = self.tokenizer(prompt, return_tensors=return_tensors, padding=True, truncation=True,
+                                           max_length=DEFAULT_MAX_LENGTH)
+            inputs = {**text_features, **audio_features}
+        else:
+            if text is None and images is None and audios is None:
+                raise ValueError("At least one of text/images/audios/messages must be provided.")
+            num_samples = 1
+            if isinstance(text, list):
+                num_samples = len(text)
+            elif images is not None and isinstance(images, list):
+                num_samples = len(images)
+            elif audios is not None and isinstance(audios, list):
+                num_samples = len(audios)
+            image_features = {}
+            if images is not None and self.image_processor is not None:
+                batched_images = make_nested_list_of_images(images)
+                img_out = self.image_processor(batched_images, return_tensors=None)
+                image_features = img_out.data if isinstance(img_out, BatchFeature) else img_out
+            audio_features = {}
+            audio_token_counts = None
+            if audios is not None and self.audio_processor is not None:
+                audio_out = self.audio_processor(audios=audios, return_tensors=None)
+                audio_features = audio_out.data
+                att_mask = audio_features[self.audio_processor.model_input_names[2]]
+                if isinstance(att_mask, torch.Tensor):
+                    frames_for_embed = att_mask.sum(dim=-1).cpu().tolist()
+                else:
+                    frames_for_embed = np.array(att_mask).sum(axis=-1).tolist()
+                audio_token_counts = [self.compute_audio_token_count(mel_frame_count) for mel_frame_count in
+                                      frames_for_embed]
+            if text is None:
+                text = [""] * num_samples
+            elif isinstance(text, str):
+                text = [text]
+            prompts = []
+            for idx in range(num_samples):
+                prompt = text[idx]
+                has_image = images is not None
+                audio_count = audio_token_counts[idx] if audio_token_counts is not None else None
+                prompt_str = prompt
+                if has_image:
+                    prompt_str = prompt_str.replace(self.boi_token, self.full_image_sequence)
+                if audio_count is not None:
+                    prompt_str = prompt_str.replace(self.boa_token, self.boa_token + (self.audio_token * audio_count))
+                prompts.append(prompt_str)
+            text_features = self.tokenizer(prompts, return_tensors=return_tensors, padding=True, truncation=True,
+                                           max_length=DEFAULT_MAX_LENGTH)
+            inputs = {**text_features}
+            if image_features:
+                inputs.update(image_features)
+            if audio_features:
+                inputs.update(audio_features)
+        if device is not None:
+            inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+        return inputs
+
+    @property
+    def model_input_names(self) -> List[str]:
+        input_names = set()
+        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
+            tokenizer_inputs = self.tokenizer.model_input_names
+            if isinstance(tokenizer_inputs, (list, set)):
+                input_names.update(tokenizer_inputs)
+            else:
+                input_names.add(str(tokenizer_inputs))
+            input_names.add("token_type_ids")
+        if hasattr(self, 'image_processor') and self.image_processor is not None:
+            image_inputs = self.image_processor.model_input_names
+            if isinstance(image_inputs, (list, set)):
+                input_names.update(image_inputs)
+            else:
+                input_names.add(str(image_inputs))
+        if hasattr(self, 'audio_processor') and self.audio_processor is not None:
+            audio_inputs = self.audio_processor.model_input_names
+            if isinstance(audio_inputs, (list, set)):
+                input_names.update(audio_inputs)
+            else:
+                input_names.add(str(audio_inputs))
+        return list(input_names)
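A hedged worked example of the audio token budget produced by compute_audio_token_count above, for 10 s of 16 kHz audio with the defaults used in this file (win 400, hop 160, feat_stride 4, compression_rate 4, qformer_rate 8):

num_samples = 10 * 16000
mel_frames = (num_samples - 400) // 160 + 1   # 998 mel frames from the framing in _extract_spectrogram
audio_frames = mel_frames * 4                 # 3992 after feat_stride
after_compression = -(-audio_frames // 4)     # ceil division -> 998
audio_tokens = -(-after_compression // 8)     # ceil division -> 125 <audio_soft_token> placeholders
print(mel_frames, audio_tokens)  # 998 125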
special_tokens_map.json ADDED
@@ -0,0 +1,36 @@
+{
+  "boi_token": "<start_of_image>",
+  "bos_token": {
+    "content": "<bos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eoi_token": "<end_of_image>",
+  "eos_token": {
+    "content": "<eos>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "image_token": "<image_soft_token>",
+  "boa_token": "<start_of_audio>",
+  "eoa_token": "<end_of_audio>",
+  "audio_token": "<audio_soft_token>",
+  "pad_token": {
+    "content": "<pad>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
speech_conformer_encoder.py ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:88787c7cf85c7d14c8dd2a29cc86f69a1a7d151f306ce00bb54fe7dc35284b0e
+size 33384534
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
+size 4689074
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff