Upload 12 files
Browse files
- .gitattributes +1 -0
- __init__.py +0 -0
- chat_template.jinja +1 -0
- config.json +8 -2
- configuration_gemma3_omni.py +55 -0
- modeling_gemma3_omni.py +461 -0
- preprocessor_config.json +46 -0
- processing_gemma3_omni.py +491 -0
- special_tokens_map.json +36 -0
- speech_conformer_encoder.py +0 -0
- tokenizer.json +3 -0
- tokenizer.model +3 -0
- tokenizer_config.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
__init__.py
ADDED
File without changes
chat_template.jinja
ADDED
@@ -0,0 +1 @@
{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' %}{% set loop_messages = messages[1:] %}{% else %}{% set first_user_prefix = '' %}{% set loop_messages = messages %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set role = 'model' if message['role'] == 'assistant' else message['role'] %}{{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else '') }}{% if role == 'model' and message.get('metadata') %}{% if message['metadata']['type'] == 'think' %}<think>{% if message['metadata'].get('range') %}<range>{{ message['metadata']['range'] }}</range>{% endif %}{% if message['metadata'].get('content') %}{{ message['metadata']['content'] | trim }}{% endif %}</think>{% elif message['metadata']['type'] == 'direct' %}<direct>{% endif %}{% if message['metadata'].get('function') %}<function>{{ message['metadata']['function'] | join(',') }}</function>{% endif %}{% endif %}{% if message['content'] is string %}{{ message['content'] | trim }}{% elif message['content'] is iterable %}{% for item in message['content'] %}{{ '<start_of_image>' if item['type']=='image' else '<start_of_audio>' if item['type']=='audio' else item['text'] | trim if item['type']=='text' else '' }}{% endfor %}{% else %}{{ raise_exception('Invalid content type') }}{% endif %}{{ '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}
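For orientation, a minimal sketch of rendering this template through the tokenizer API; the repo id is a placeholder and the rendered string in the comment assumes the default special tokens defined in this repo:

# Hedged usage sketch; "user/gemma3-omni" is a placeholder repo id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("user/gemma3-omni")
messages = [
    {"role": "user", "content": [
        {"type": "audio"},                              # expands to <start_of_audio>
        {"type": "text", "text": "Transcribe this clip."},
    ]},
]
# Expected rendering (roughly):
# <bos><start_of_turn>user\n<start_of_audio>Transcribe this clip.<end_of_turn>\n<start_of_turn>model\n
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)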
config.json
CHANGED
@@ -2,7 +2,13 @@
   "architectures": [
     "Gemma3OmniForConditionalGeneration"
   ],
-  "
+  "auto_map": {
+    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
+    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor",
+    "AutoModel": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
+    "AutoModelForCausalLM": "modeling_gemma3_omni.Gemma3OmniForConditionalGeneration",
+    "AutoConfig": "configuration_gemma3_omni.Gemma3OmniConfig"
+  },
   "boi_token_index": 255999,
   "eoi_token_index": 256000,
   "eos_token_id": [
@@ -122,4 +128,4 @@
     "torch_dtype": "float32",
     "vision_use_head": false
   }
-}
+}
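The auto_map entries above are what let the Hub resolve these custom classes at load time. A minimal loading sketch (the repo id is a placeholder; assumes trust_remote_code):

# Hedged loading sketch; "user/gemma3-omni" is a placeholder repo id.
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

repo = "user/gemma3-omni"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)           # -> Gemma3OmniConfig
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)     # -> Gemma3OmniProcessor
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)  # -> Gemma3OmniForConditionalGeneration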
configuration_gemma3_omni.py
ADDED
@@ -0,0 +1,55 @@
from typing import Optional, Union, Dict, Any

from transformers import Gemma3TextConfig, SiglipVisionConfig, PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Gemma3OmniConfig(PretrainedConfig):
    model_type = "gemma3omni"
    attribute_map = {
        "image_token_id": "image_token_index",
        "audio_token_id": "audio_token_index",
        "boi_token_id": "boi_token_index",
        "eoi_token_id": "eoi_token_index",
    }
    sub_configs = {
        "text_config": Gemma3TextConfig,
        "vision_config": SiglipVisionConfig,
    }

    def __init__(
        self,
        text_config: Optional[Union[Gemma3TextConfig, Dict[str, Any]]] = None,
        vision_config: Optional[Union[SiglipVisionConfig, Dict[str, Any]]] = None,
        mm_tokens_per_image: int = 256,
        boi_token_index: int = 255_999,
        eoi_token_index: int = 256_000,
        image_token_index: int = 262_144,
        audio_token_index: int = 262_151,
        initializer_range: float = 0.02,
        **kwargs,
    ):
        if text_config is None:
            text_config = Gemma3TextConfig()
            logger.info("text_config is None, using default Gemma3TextConfig text config.")
        elif isinstance(text_config, dict):
            text_config = Gemma3TextConfig(**text_config)

        if isinstance(vision_config, dict):
            vision_config = SiglipVisionConfig(**vision_config)
        elif vision_config is None:
            vision_config = SiglipVisionConfig()
            logger.info("vision_config is None, using default SiglipVisionConfig vision config.")

        self.text_config = text_config
        self.vision_config = vision_config
        self.mm_tokens_per_image = mm_tokens_per_image
        self.boi_token_index = boi_token_index
        self.eoi_token_index = eoi_token_index
        self.image_token_index = image_token_index
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range

        super().__init__(**kwargs)
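A small instantiation sketch for this config class (assumes the module is importable from the working directory; the values shown mirror the defaults above):

# Hedged sketch: builds the composite config with stock text/vision sub-configs.
from configuration_gemma3_omni import Gemma3OmniConfig

cfg = Gemma3OmniConfig(mm_tokens_per_image=256, audio_token_index=262_151)
print(cfg.model_type, type(cfg.text_config).__name__, type(cfg.vision_config).__name__)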
modeling_gemma3_omni.py
ADDED
@@ -0,0 +1,461 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import List, Optional, Tuple, Union, Callable

import torch
from torch import nn

from transformers import (
    AutoModel,
    Cache,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.generation import GenerationMixin
from transformers.masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3CausalLMOutputWithPast,
    Gemma3RMSNorm, Gemma3PreTrainedModel, Gemma3ModelOutputWithPast,
)
from transformers.utils import is_torchdynamo_compiling, logging, is_torch_flex_attn_available

try:
    from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
except Exception:
    LigerFusedLinearCrossEntropyLoss = None

from .configuration_gemma3_omni import Gemma3OmniConfig
from .speech_conformer_encoder import ConformerEncoder

logger = logging.get_logger(__name__)

if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask


class Gemma3AudioProjectorConfig(PretrainedConfig):
    model_type = "gemma3_audio"

    def __init__(
        self,
        hidden_size: int = 1024,
        num_hidden_layers: int = 24,
        sample_rate: int = 16_000,
        n_mels: int = 80,
        image_token_index: int = 0,  # This seems unused for the audio projector, maybe a copy-paste?
        # Added Mel spectrogram specific parameters
        n_fft: int = 400,  # Typical for 25ms window at 16kHz
        hop_length: int = 160,  # Typical for 10ms hop at 16kHz
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.image_token_index = image_token_index
        self.n_fft = n_fft
        self.hop_length = hop_length


class LayerWiseWeightedSum(nn.Module):
    def __init__(self, num_layers: int, learnable: bool = True):
        super().__init__()
        self.num_layers = num_layers
        if learnable:
            self.scalar = nn.Parameter(torch.zeros(num_layers))
        else:
            self.register_buffer("scalar", torch.zeros(num_layers))

    def forward(self, hidden_states: list[torch.Tensor]) -> torch.Tensor:
        assert len(hidden_states) == self.num_layers
        norm_w = torch.softmax(self.scalar, dim=0).view(-1, 1, 1, 1)
        stacked = torch.stack(hidden_states, dim=0)
        return (norm_w * stacked).sum(dim=0)


class Gemma3AudioProjector(PreTrainedModel):
    """Conformer-based audio encoder → project to LM hidden-dim."""

    config_class = Gemma3AudioProjectorConfig
    base_model_prefix = "audio_projector"

    def __init__(self, config: Gemma3AudioProjectorConfig):
        super().__init__(config)
        encoder_config = {
            "activation": "swish",
            "activation_checkpointing": "",
            "attention_dim": 1024,
            "attention_heads": 16,
            "batch_norm": False,
            "bias_in_glu": True,
            "causal": True,
            "chunk_size": -1,
            "conv_activation": "swish",
            "conv_glu_type": "swish",
            "depthwise_multiplier": 1,
            "depthwise_seperable_out_channel": 1024,
            "dropout_rate": 0.0,
            "encoder_embedding_config": {
                "input_size": config.n_mels  # This is feat_in for NemoConvSubsampling
            },
            "ext_pw_kernel_size": 1,
            "ext_pw_out_channel": 1024,
            "input_layer": "nemo_conv",
            "input_size": config.n_mels,  # Also feat_in for NemoConvSubsampling, for consistency
            "kernel_size": 3,
            "left_chunk": 18,
            "linear_units": 1536,
            "nemo_conv_settings": {
                "conv_channels": 1024,
            },
            "num_blocks": 24,
            "relative_attention_bias_args": {
                "t5_bias_max_distance": 500,
                "type": "t5"
            },
            "time_reduction": 8
        }
        self.encoder = ConformerEncoder(**encoder_config)
        self.layer_weighter = LayerWiseWeightedSum(
            num_layers=encoder_config["num_blocks"]
        )
        self.proj = nn.Linear(encoder_config['attention_dim'], config.hidden_size, bias=False)

    def forward(self, mel: torch.Tensor, mel_mask: torch.Tensor):
        mel = mel.squeeze(1)  # (B, T, 80)
        mel_mask = mel_mask.squeeze(1)  # (B, L)

        if mel_mask.size(1) != mel.size(1):
            mel_mask = mel_mask[..., : mel.size(1)]

        _, out_mask, hidden_list = self.encoder(mel, mel_mask)
        hidden_sum = self.layer_weighter(hidden_list)
        hidden = self.proj(hidden_sum)
        return hidden, out_mask


class Gemma3VisionProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(config.vision_config.hidden_size, config.text_config.hidden_size)
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
        )
        self.patches_per_image = config.vision_config.image_size // config.vision_config.patch_size
        self.tokens_per_side = int(config.mm_tokens_per_image ** 0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)

    def forward(self, vision_outputs: torch.Tensor):
        b, _, seq_len = vision_outputs.shape
        x = vision_outputs.transpose(1, 2).reshape(
            b, seq_len, self.patches_per_image, self.patches_per_image
        )
        x = self.avg_pool(x).flatten(2).transpose(1, 2)
        x = self.mm_soft_emb_norm(x)
        return torch.matmul(x, self.mm_input_projection_weight).type_as(vision_outputs)


def token_type_ids_mask_function(token_type_ids: Optional[torch.Tensor]) -> Optional[Callable]:
    """
    This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
    not start and end indices.
    """
    # Do not return an additional mask in this case
    if token_type_ids is None:
        return None

    def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        # If it's 1, we need to unmask it
        return token_type_ids[batch_idx, kv_idx] == 1

    return inner_mask


class Gemma3OmniModel(Gemma3PreTrainedModel):
    config_class = Gemma3OmniConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3VisionProjector(config)
        self.audio_projector = Gemma3AudioProjector(
            Gemma3AudioProjectorConfig(hidden_size=config.text_config.hidden_size)
        )
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModel.from_config(config=config.text_config)
        self.language_model = language_model

        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.embed_tokens

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        input_audio_embeds: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3ModelOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            print("input_ids:", input_ids, "inputs_embeds:", inputs_embeds)
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Replace the image id with PAD if the image token is OOV, to avoid index errors
        if input_ids is not None and self.config.image_token_id >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_id
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if pixel_values is not None and past_key_values is None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        if input_audio_embeds is not None and past_key_values is None:
            audio_features, audio_feat_mask = self.audio_projector(
                input_audio_embeds, audio_attention_mask
            )
            if input_ids is None:
                special_audio_mask = (
                    inputs_embeds
                    == self.get_input_embeddings()(
                        torch.tensor(
                            self.config.audio_token_index,
                            dtype=torch.long,
                            device=inputs_embeds.device,
                        )
                    )
                )
            else:
                special_audio_mask = (
                    input_ids == self.config.audio_token_index
                ).unsqueeze(-1)  # [B, L, 1]
                special_audio_mask = special_audio_mask.expand_as(inputs_embeds).to(
                    inputs_embeds.device
                )
            if (
                not is_torchdynamo_compiling()
                and inputs_embeds[special_audio_mask].numel() != audio_features.numel()
            ):
                audio_tokens_in_text = special_audio_mask.sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of audio tokens in the text ({audio_tokens_in_text}) "
                    f"≠ number of tokens from audio embeddings "
                    f"({audio_features.shape[0] * audio_features.shape[1]})."
                )
            audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
            audio_features = audio_features.reshape(-1)
            inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_features)

        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config.get_text_config(),
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
            }
            if token_type_ids is not None and inputs_embeds.shape[1] != 1:
                mask_kwargs["or_mask_function"] = token_type_ids_mask_function(
                    token_type_ids.to(cache_position.device)
                )

            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
                "sliding_attention": create_sliding_window_causal_mask(**mask_kwargs),
            }

        outputs = self.language_model(
            attention_mask=causal_mask_mapping,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **lm_kwargs,
        )

        return Gemma3ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values if use_cache else None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )


class Gemma3OmniForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
    """Gemma-3 Omni: vision + audio + text causal LM."""

    config_class = Gemma3OmniConfig
    _checkpoint_conversion_mapping = {
        "^language_model.model": "model.language_model",
        "^vision_tower": "model.vision_tower",
        "^multi_modal_projector": "model.multi_modal_projector",
        "^language_model.lm_head": "lm_head",
    }
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma3OmniModel(config)
        self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.language_model.embed_tokens

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        input_audio_embeds: Optional[torch.FloatTensor] = None,
        audio_attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, Gemma3CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            input_audio_embeds=input_audio_embeds,
            audio_attention_mask=audio_attention_mask,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **lm_kwargs,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop the attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1]:].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()

            # Flatten the tokens
            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)

            if LigerFusedLinearCrossEntropyLoss is not None:
                loss_fct = LigerFusedLinearCrossEntropyLoss()
            else:
                loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Gemma3CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=outputs.image_hidden_states,
        )


__all__ = [
    "Gemma3AudioProjectorConfig",
    "Gemma3AudioProjector",
    "Gemma3VisionProjector",
    "Gemma3OmniModel",
    "Gemma3OmniForConditionalGeneration",
]
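A hedged end-to-end generation sketch for this model (untested; the repo id and the silent waveform are placeholders, and it assumes a 16 kHz mono clip plus trust_remote_code):

# Minimal inference sketch under the assumptions stated above.
import numpy as np
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "user/gemma3-omni"  # placeholder repo id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True).eval()

waveform = np.zeros(16000, dtype=np.float32)  # stand-in for one second of 16 kHz audio
prompt = "<start_of_turn>user\n<start_of_audio>Describe the audio.<end_of_turn>\n<start_of_turn>model\n"
inputs = processor(text=prompt, audios=[(waveform, 16000)], return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64)
print(processor.tokenizer.decode(out[0], skip_special_tokens=True))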
preprocessor_config.json
ADDED
@@ -0,0 +1,46 @@
{
  "auto_map": {
    "AutoProcessor": "processing_gemma3_omni.Gemma3OmniProcessor",
    "AutoFeatureExtractor": "processing_gemma3_omni.Gemma3AudioFeatureExtractor"
  },
  "do_convert_rgb": null,
  "do_normalize": true,
  "do_pan_and_scan": null,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "Gemma3ImageProcessor",
  "processor_class": "Gemma3Processor",
  "image_seq_length": 256,
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "pan_and_scan_max_num_crops": null,
  "pan_and_scan_min_crop_size": null,
  "pan_and_scan_min_ratio_to_activate": null,
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 896,
    "width": 896
  },
  "compression_rate": 4,
  "feat_stride": 4,
  "feature_extractor_type": "Gemma3AudioFeatureExtractor",
  "feature_size": 80,
  "hop_length": 160,
  "n_fft": 512,
  "padding_side": "right",
  "padding_value": 0.0,
  "processor_class": "Gemma3OmniProcessor",
  "qformer_rate": 2,
  "return_attention_mask": true,
  "sampling_rate": 16000,
  "win_length": 400
}
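To make the audio bookkeeping concrete, here is a small sketch of the ceil-divide chain that compute_audio_token_count in processing_gemma3_omni.py applies (shown with feat_stride=4, compression_rate=4 and the qformer rate of 8 that the processor code defaults to, which differs from the qformer_rate of 2 recorded in this config):

# Illustrative arithmetic only; mirrors compute_audio_token_count.
def audio_token_count(mel_frames, feat_stride=4, compression_rate=4, qformer_rate=8):
    frames = mel_frames * feat_stride
    frames = (frames + compression_rate - 1) // compression_rate   # temporal compression
    frames = (frames + qformer_rate - 1) // qformer_rate           # q-former style downsampling
    return frames

# ~10 s of 16 kHz audio -> ~1000 mel frames (hop_length=160) -> 4000 -> 1000 -> 125 soft tokens
print(audio_token_count(1000))  # 125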
processing_gemma3_omni.py
ADDED
@@ -0,0 +1,491 @@
from typing import List, Optional, Union, Dict, Any, Tuple

import numpy as np
import scipy.signal
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import make_nested_list_of_images
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType, logging

DEFAULT_SPECIAL_TOKENS = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
    "unk_token": "<unk>",
    "boi_token": "<start_of_image>",
    "eoi_token": "<end_of_image>",
    "image_token": "<image_soft_token>",
    "boa_token": "<start_of_audio>",
    "eoa_token": "<end_of_audio>",
    "audio_token": "<audio_soft_token>",
}
DEFAULT_SAMPLING_RATE = 16000
DEFAULT_N_FFT = 512
DEFAULT_WIN_LENGTH = 400
DEFAULT_HOP_LENGTH = 160
DEFAULT_N_MELS = 80
DEFAULT_COMPRESSION_RATE = 4
DEFAULT_QFORMER_RATE = 8
DEFAULT_FEAT_STRIDE = 4
DEFAULT_IMAGE_SEQ_LENGTH = 256
DEFAULT_MAX_LENGTH = 16384

logger = logging.get_logger(__name__)


def compute_audio_token_count(
    mel_frame_count: int,
    *,
    feat_stride: int = DEFAULT_FEAT_STRIDE,
    compression_rate: int = DEFAULT_COMPRESSION_RATE,
    qformer_rate: int = DEFAULT_QFORMER_RATE,
) -> int:
    audio_frames = mel_frame_count * feat_stride
    audio_frames = (audio_frames + compression_rate - 1) // compression_rate
    audio_frames = (audio_frames + qformer_rate - 1) // qformer_rate
    return audio_frames


def speechlib_mel(sample_rate, n_fft, n_mels, fmin=None, fmax=None):
    bank_width = int(n_fft // 2 + 1)
    if fmax is None:
        fmax = sample_rate / 2
    if fmin is None:
        fmin = 0

    def mel(f):
        return 1127.0 * np.log(1.0 + f / 700.0)

    def bin2mel(fft_bin):
        return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0))

    def f2bin(f):
        return int((f * n_fft / sample_rate) + 0.5)

    klo = f2bin(fmin) + 1
    khi = f2bin(fmax)
    khi = max(khi, klo)
    mlo = mel(fmin)
    mhi = mel(fmax)
    m_centers = np.linspace(mlo, mhi, n_mels + 2)
    ms = (mhi - mlo) / (n_mels + 1)
    matrix = np.zeros((n_mels, bank_width), dtype=np.float32)
    for m in range(n_mels):
        left = m_centers[m]
        center = m_centers[m + 1]
        right = m_centers[m + 2]
        for fft_bin in range(klo, khi):
            mbin = bin2mel(fft_bin)
            if left < mbin < right:
                matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms
    return matrix


class Gemma3AudioFeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["input_audio_embeds", "audio_embed_sizes", "audio_attention_mask"]

    def __init__(
        self,
        audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
        audio_downsample_rate: int = DEFAULT_QFORMER_RATE,
        audio_feat_stride: int = DEFAULT_FEAT_STRIDE,
        feature_size: int = DEFAULT_N_MELS,
        sampling_rate: int = DEFAULT_SAMPLING_RATE,
        padding_value: float = 0.0,
        eightk_method: str = "fillzero",
        **kwargs,
    ):
        super().__init__(
            feature_size=kwargs.pop("feature_size", feature_size),
            sampling_rate=kwargs.pop("sampling_rate", sampling_rate),
            padding_value=kwargs.pop("padding_value", padding_value),
            **kwargs,
        )
        self.compression_rate = audio_compression_rate
        self.qformer_compression_rate = audio_downsample_rate
        self.feat_stride = audio_feat_stride
        self._eightk_method = eightk_method
        self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T
        self._hamming400 = np.hamming(400)
        self._hamming200 = np.hamming(200)

    def __call__(
        self,
        audios: List[Tuple[np.ndarray, int]],
        return_tensors: Optional[Union[str, TensorType]] = None,
    ):
        returned_input_audio_embeds = []
        returned_audio_embed_sizes = []
        audio_frames_list = []
        for audio_data, sample_rate in audios:
            if isinstance(audio_data, list):
                audio_data = np.array(audio_data, dtype=np.float32)
            if not isinstance(audio_data, np.ndarray):
                raise TypeError(f"Waveform data must be a numpy array, got {type(audio_data)}")
            audio_embeds_np = self._extract_features(audio_data, sample_rate)
            num_mel_frames = audio_embeds_np.shape[0]
            current_audio_frames = num_mel_frames * self.feat_stride
            audio_embed_size = self._compute_audio_embed_size(current_audio_frames)
            returned_input_audio_embeds.append(torch.from_numpy(audio_embeds_np))
            returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long())
            audio_frames_list.append(current_audio_frames)
        padded_input_audio_embeds = pad_sequence(
            returned_input_audio_embeds, batch_first=True, padding_value=self.padding_value
        )
        stacked_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0)
        tensor_audio_frames = torch.tensor(audio_frames_list, dtype=torch.long)
        max_audio_frames = tensor_audio_frames.max().item() if tensor_audio_frames.numel() > 0 else 0
        if max_audio_frames > 0 and len(audios) > 1:
            audio_attention_mask = (
                torch.arange(0, max_audio_frames, device=tensor_audio_frames.device).unsqueeze(0)
                < tensor_audio_frames.unsqueeze(1)
            )
        elif max_audio_frames > 0:
            audio_attention_mask = torch.ones(1, max_audio_frames, dtype=torch.bool, device=tensor_audio_frames.device)
        else:
            audio_attention_mask = None
        data = {
            "input_audio_embeds": padded_input_audio_embeds,
            "audio_embed_sizes": stacked_audio_embed_sizes,
        }
        if audio_attention_mask is not None:
            data["audio_attention_mask"] = audio_attention_mask
        return BatchFeature(data=data, tensor_type=return_tensors)

    def _extract_spectrogram(self, wav: np.ndarray, fs: int) -> np.ndarray:
        if wav.ndim > 1:
            wav = np.squeeze(wav)
        if len(wav.shape) == 2:
            wav = wav.mean(axis=1).astype(np.float32)
        wav = wav.astype(np.float32)
        current_fs = fs
        if current_fs > self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, current_fs)
            current_fs = self.sampling_rate
        elif 8000 < current_fs < self.sampling_rate:
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs < 8000 and current_fs > 0:
            wav = scipy.signal.resample_poly(wav, 8000, current_fs)
            current_fs = 8000
        elif current_fs <= 0:
            raise RuntimeError(f"Unsupported sample rate {current_fs}")
        if current_fs == 8000 and self._eightk_method == "resample":
            wav = scipy.signal.resample_poly(wav, self.sampling_rate, 8000)
            current_fs = self.sampling_rate
        elif current_fs != self.sampling_rate:
            raise RuntimeError(
                f"Audio sample rate {current_fs} not supported. Expected {self.sampling_rate} or 8000 for 8k methods.")
        preemphasis_coeff = 0.97
        if current_fs == 8000:
            n_fft, win_length, hop_length, fft_window = 256, 200, 80, self._hamming200
        else:
            n_fft, win_length, hop_length, fft_window = 512, 400, 160, self._hamming400
        if len(wav) < win_length:
            wav = np.pad(wav, (0, win_length - len(wav)), 'constant', constant_values=(0.0,))
        num_frames = (wav.shape[0] - win_length) // hop_length + 1
        if num_frames <= 0:
            return np.zeros((0, n_fft // 2 + 1), dtype=np.float32)
        y_frames = np.array(
            [wav[i * hop_length: i * hop_length + win_length] for i in range(num_frames)],
            dtype=np.float32,
        )
        _y_frames_rolled = np.roll(y_frames, 1, axis=1)
        _y_frames_rolled[:, 0] = _y_frames_rolled[:, 1]
        y_frames_preemphasized = (y_frames - preemphasis_coeff * _y_frames_rolled) * 32768.0
        S = np.fft.rfft(fft_window * y_frames_preemphasized, n=n_fft, axis=1).astype(np.complex64)
        if current_fs == 8000 and self._eightk_method == "fillzero":
            target_bins = (512 // 2) + 1
            S_core = S[:, :-1]
            padarray = np.zeros((S_core.shape[0], target_bins - S_core.shape[1]), dtype=S.dtype)
            S = np.concatenate((S_core, padarray), axis=1)
        spec = np.abs(S).astype(np.float32)
        return spec

    def _extract_features(self, wav: np.ndarray, fs: int) -> np.ndarray:
        spec = self._extract_spectrogram(wav, fs)
        if spec.shape[0] == 0:
            return np.zeros((0, self.feature_size), dtype=np.float32)
        spec_power = spec ** 2
        fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None)
        log_fbank = np.log(fbank_power).astype(np.float32)
        return log_fbank

    def _compute_audio_embed_size(self, audio_frames: int) -> int:
        integer = audio_frames // self.compression_rate
        remainder = audio_frames % self.compression_rate
        result = integer if remainder == 0 else integer + 1
        integer = result // self.qformer_compression_rate
        remainder = result % self.qformer_compression_rate
        result = integer if remainder == 0 else integer + 1
        return result


class Gemma3OmniProcessor(ProcessorMixin):
    attributes = ["image_processor", "audio_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    audio_processor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        audio_processor=None,
        tokenizer=None,
        special_tokens: Optional[Dict[str, str]] = None,
        image_seq_length: int = DEFAULT_IMAGE_SEQ_LENGTH,
        prompt_audio_compression_rate: int = DEFAULT_COMPRESSION_RATE,
        prompt_audio_qformer_rate: int = DEFAULT_QFORMER_RATE,
        audio_placeholder_token: str = "<|audio_placeholder|>",
        **kwargs,
    ):
        super().__init__(
            image_processor=image_processor,
            audio_processor=audio_processor,
            tokenizer=tokenizer,
            **kwargs,
        )
        self.special_tokens = dict(DEFAULT_SPECIAL_TOKENS)
        if special_tokens is not None:
            self.special_tokens.update(special_tokens)
        if tokenizer is not None:
            for key in self.special_tokens:
                val = getattr(tokenizer, key, None)
                if isinstance(val, str):
                    self.special_tokens[key] = val
        self.image_token = self.special_tokens["image_token"]
        self.audio_token = self.special_tokens["audio_token"]
        self.boi_token = self.special_tokens["boi_token"]
        self.eoi_token = self.special_tokens["eoi_token"]
        self.boa_token = self.special_tokens["boa_token"]
        self.eoa_token = self.special_tokens["eoa_token"]
        self.image_seq_length = image_seq_length
        self.full_image_sequence = f"{self.boi_token}{''.join([self.image_token] * self.image_seq_length)}{self.eoi_token}"
        self.prompt_audio_compression_rate = prompt_audio_compression_rate
        self.prompt_audio_qformer_rate = prompt_audio_qformer_rate
        self.audio_placeholder_token = audio_placeholder_token
        if self.tokenizer is not None:
            self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
            self.audio_token_id = self.tokenizer.convert_tokens_to_ids(self.audio_token)
        else:
            self.image_token_id = None
            self.audio_token_id = None

    def compute_audio_token_count(self, mel_frame_count: int) -> int:
        stride = getattr(self.audio_processor, "feat_stride", DEFAULT_FEAT_STRIDE)
        return compute_audio_token_count(
            mel_frame_count,
            feat_stride=stride,
            compression_rate=self.prompt_audio_compression_rate,
            qformer_rate=self.prompt_audio_qformer_rate,
        )

    def apply_chat_template(
        self,
        messages,
        add_generation_prompt: bool = True,
        tokenize: bool = False,
        **kwargs
    ) -> Union[str, Dict[str, Any]]:
        prompt = ""
        # Default media containers; only populated when a top-level dict carries them.
        audios = []
        images = []
        if isinstance(messages, dict) and "messages" in messages:
            if "audios" in messages:
                audios = messages["audios"]
            if "audio" in messages:
                audios = [messages["audio"]]
            if "images" in messages:
                images = messages["images"]
            if "image" in messages:
                images = [messages["image"]]
            messages = messages["messages"]

        for msg in messages:
            role = msg.get("role", "")
            prompt += f"<start_of_turn>{role}\n"
            contents = msg.get("content", [])
            if not isinstance(contents, list):
                contents = [contents]

            for c in contents:
                if isinstance(c, dict):
                    ctype = c.get("type")
                    if ctype == "image":
                        idx = c.get("index")
                        img_data = None
                        if idx is not None and isinstance(idx, int):
                            img_data = images[idx]
                        elif "image" in c:
                            img_data = c["image"]
                        if img_data is None:
                            logger.warning("No image data found for image content: %s", c)
                        prompt += self.full_image_sequence
                        continue

                    if ctype == "audio":
                        idx = c.get("index")
                        aud_data = None
                        if idx is not None and isinstance(idx, int):
                            aud_data = audios[idx]["array"]
                            sr = audios[idx].get("sampling_rate",
                                                 self.audio_processor.sampling_rate if self.audio_processor else DEFAULT_SAMPLING_RATE)
                        elif "audio" in c:
                            aud_data = c["audio"]
                            sr = c.get("sampling_rate",
                                       self.audio_processor.sampling_rate if self.audio_processor else DEFAULT_SAMPLING_RATE)
                        if aud_data is None:
                            logger.warning("No audio data found for audio content: %s", c)

                        n_audio_tokens = 0
                        if self.audio_processor:
                            features = self.audio_processor(audios=[(aud_data, sr)], return_tensors=None)
                            mel_frame_count = features["input_audio_embeds"].shape[1]
                            n_audio_tokens = self.compute_audio_token_count(mel_frame_count)
                        prompt += (
                            self.boa_token +
                            (self.audio_token * n_audio_tokens) +
                            self.eoa_token
                        )
                        continue

                    if ctype == "text" and "text" in c:
                        prompt += str(c["text"])
                        continue
                    continue

                elif isinstance(c, str):
                    prompt += c
                    continue
                else:
                    logger.warning("Unknown content type in message: %s", c)
                    continue

            prompt += "<end_of_turn>\n"

        if add_generation_prompt:
            prompt += "<start_of_turn>model\n"

        if tokenize and self.tokenizer is not None:
            safe_kwargs = {}
            allowed_keys = {"return_tensors", "padding", "truncation", "max_length", "add_special_tokens"}
            for k, v in kwargs.items():
                if k in allowed_keys:
                    safe_kwargs[k] = v
            return self.tokenizer(prompt, **safe_kwargs)

        return prompt

    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        images: Optional[Union[Any, List[Any]]] = None,
        audios: Optional[Union[Tuple[np.ndarray, int], List[Tuple[np.ndarray, int]]]] = None,
        messages: Optional[List[Dict]] = None,
        add_generation_prompt: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = "pt",
        device: Optional[str] = None,
        **kwargs
    ) -> Dict[str, Any]:
        if messages is not None:
            if isinstance(messages, dict):
                messages = [messages]
            prompt = self.apply_chat_template(
                messages,
                add_generation_prompt=add_generation_prompt,
                tokenize=False,
            )
            audio_inputs = []
            for msg in messages:
                contents = msg.get("content", [])
                if not isinstance(contents, list):
                    contents = [contents]
                for c in contents:
                    if isinstance(c, dict) and c.get("type") == "audio":
                        arr = c["audio"]
                        sr = c.get("sampling_rate",
                                   self.audio_processor.sampling_rate if self.audio_processor else 16000)
                        audio_inputs.append((arr, sr))
            audio_features = {}
            if audio_inputs and self.audio_processor is not None:
                audio_features = self.audio_processor(audios=audio_inputs, return_tensors=return_tensors)
            text_features = self.tokenizer(prompt, return_tensors=return_tensors, padding=True, truncation=True,
                                           max_length=DEFAULT_MAX_LENGTH)
            inputs = {**text_features, **audio_features}
        else:
            if text is None and images is None and audios is None:
                raise ValueError("At least one of text/images/audios/messages must be provided.")
            num_samples = 1
            if isinstance(text, list):
                num_samples = len(text)
            elif images is not None and isinstance(images, list):
                num_samples = len(images)
            elif audios is not None and isinstance(audios, list):
                num_samples = len(audios)
            image_features = {}
            if images is not None and self.image_processor is not None:
                batched_images = make_nested_list_of_images(images)
                img_out = self.image_processor(batched_images, return_tensors=None)
                image_features = img_out.data if isinstance(img_out, BatchFeature) else img_out
            audio_features = {}
            audio_token_counts = None
            if audios is not None and self.audio_processor is not None:
                audio_out = self.audio_processor(audios=audios, return_tensors=None)
                audio_features = audio_out.data
                att_mask = audio_features[self.audio_processor.model_input_names[2]]
                if isinstance(att_mask, torch.Tensor):
                    frames_for_embed = att_mask.sum(dim=-1).cpu().tolist()
                else:
                    frames_for_embed = np.array(att_mask).sum(axis=-1).tolist()
                audio_token_counts = [self.compute_audio_token_count(mel_frame_count) for mel_frame_count in
                                      frames_for_embed]
            if text is None:
                text = [""] * num_samples
            elif isinstance(text, str):
                text = [text]
            prompts = []
            for idx in range(num_samples):
                prompt = text[idx]
                has_image = images is not None
                audio_count = audio_token_counts[idx] if audio_token_counts is not None else None
                prompt_str = prompt
                if has_image:
                    prompt_str = prompt_str.replace(self.boi_token, self.full_image_sequence)
                if audio_count is not None:
                    prompt_str = prompt_str.replace(self.boa_token, self.boa_token + (self.audio_token * audio_count))
                prompts.append(prompt_str)
            text_features = self.tokenizer(prompts, return_tensors=return_tensors, padding=True, truncation=True,
                                           max_length=DEFAULT_MAX_LENGTH)
            inputs = {**text_features}
            if image_features:
                inputs.update(image_features)
            if audio_features:
                inputs.update(audio_features)
        if device is not None:
            inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        return inputs

    @property
    def model_input_names(self) -> List[str]:
        input_names = set()
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            tokenizer_inputs = self.tokenizer.model_input_names
            if isinstance(tokenizer_inputs, (list, set)):
                input_names.update(tokenizer_inputs)
            else:
                input_names.add(str(tokenizer_inputs))
            input_names.add("token_type_ids")
        if hasattr(self, 'image_processor') and self.image_processor is not None:
            image_inputs = self.image_processor.model_input_names
            if isinstance(image_inputs, (list, set)):
                input_names.update(image_inputs)
            else:
                input_names.add(str(image_inputs))
        if hasattr(self, 'audio_processor') and self.audio_processor is not None:
            audio_inputs = self.audio_processor.model_input_names
            if isinstance(audio_inputs, (list, set)):
                input_names.update(audio_inputs)
            else:
                input_names.add(str(audio_inputs))
        return list(input_names)
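A minimal sketch of calling the audio feature extractor directly (synthetic audio; assumes the module is importable from the working directory and the shapes in the comments are approximate):

# Hedged usage sketch for Gemma3AudioFeatureExtractor.
import numpy as np
from processing_gemma3_omni import Gemma3AudioFeatureExtractor

fe = Gemma3AudioFeatureExtractor()
wav = np.random.randn(32000).astype(np.float32)   # ~2 s at 16 kHz
batch = fe(audios=[(wav, 16000)], return_tensors="pt")
print(batch["input_audio_embeds"].shape)          # roughly (1, 198, 80) log-mel frames
print(batch["audio_embed_sizes"])                 # downsampled embed size per clip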
special_tokens_map.json
ADDED
@@ -0,0 +1,36 @@
{
  "boi_token": "<start_of_image>",
  "bos_token": {
    "content": "<bos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eoi_token": "<end_of_image>",
  "eos_token": {
    "content": "<eos>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "image_token": "<image_soft_token>",
  "boa_token": "<start_of_audio>",
  "eoa_token": "<end_of_audio>",
  "audio_token": "<audio_soft_token>",
  "pad_token": {
    "content": "<pad>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
speech_conformer_encoder.py
ADDED
The diff for this file is too large to render.
See raw diff
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:88787c7cf85c7d14c8dd2a29cc86f69a1a7d151f306ce00bb54fe7dc35284b0e
size 33384534
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1299c11d7cf632ef3b4e11937501358ada021bbdf7c47638d13c0ee982f2e79c
size 4689074
tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff