kcz358 committed
Commit 3d429c6 · verified · 1 parent: 798c0f7

Upload modeling_aero.py with huggingface_hub

Files changed (1)
  modeling_aero.py  +391 -0
modeling_aero.py ADDED
@@ -0,0 +1,391 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Aero model."""

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers import AutoConfig, AutoModel
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModel, AutoModelForCausalLM
from transformers.utils import logging

from .configuration_aero import AeroConfig

logger = logging.get_logger(__name__)


@dataclass
# Copied from transformers.models.llava_next_video.modeling_llava_next_video.LlavaNextVideoCausalLMOutputWithPast with LlavaNextVideo->Aero
class AeroCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Aero causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        audio_hidden_states (`torch.FloatTensor`, *optional*):
            Audio hidden states produced by the audio encoder, after the last hidden state has been projected into the
            language model's hidden size.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    audio_hidden_states: Optional[torch.FloatTensor] = None

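To make the field layout above concrete, here is a small sketch that constructs this output class with dummy tensors; all shapes and values are illustrative only and are not part of the uploaded file.

import torch

# Illustrative only: a 2-sample batch, 8-token sequence, 32k vocabulary,
# and 5 audio placeholder tokens across the batch.
dummy = AeroCausalLMOutputWithPast(
    loss=torch.tensor(2.31),
    logits=torch.randn(2, 8, 32000),
    past_key_values=None,
    hidden_states=None,
    attentions=None,
    audio_hidden_states=torch.randn(5, 3584),  # one row per audio placeholder token
)
print(dummy.logits.shape)                      # torch.Size([2, 8, 32000])
print(dummy.loss, dummy.audio_hidden_states.shape)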


class AeroAudioMultiModalProjector(nn.Module):
    # Maps audio-encoder features (audio_config.d_model) into the language model's
    # embedding space (text_config.hidden_size) with a single linear layer.
    def __init__(self, config: AeroConfig):
        super().__init__()
        self.linear = nn.Linear(
            config.audio_config.d_model, config.text_config.hidden_size, bias=True
        )

    def forward(self, audio_features):
        hidden_states = self.linear(audio_features)
        return hidden_states

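A quick shape check for the projector above, using a stand-in config built with SimpleNamespace; the dimensions (1280 and 3584) are made-up placeholders, while the real values come from AeroConfig.

import torch
from types import SimpleNamespace

# Stand-in config (illustrative values; the real ones live in configuration_aero.py).
cfg = SimpleNamespace(
    audio_config=SimpleNamespace(d_model=1280),
    text_config=SimpleNamespace(hidden_size=3584),
)
projector = AeroAudioMultiModalProjector(cfg)
audio_feats = torch.randn(2, 750, 1280)   # (batch, audio_seq_len, d_model)
print(projector(audio_feats).shape)       # torch.Size([2, 750, 3584])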


class AeroPreTrainedModel(PreTrainedModel):
    config_class = AeroConfig
    base_model_prefix = "language_model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True
    # Qwen2 does not support static cache, so it is disabled here even though the
    # LLaVA-style wrapper itself has no reason not to support it.
    _supports_static_cache = False
    _supports_quantized_cache = True
    _supports_sdpa = True

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextPreTrainedModel._init_weights
    def _init_weights(self, module):
        # important: this ported version of LlavaNext isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class AeroForConditionalGeneration(AeroPreTrainedModel, GenerationMixin):
    def __init__(self, config: AeroConfig):
        super().__init__(config)

        self.audio_tower_type = config.audio_config.model_type
        self.audio_tower = AutoModel.from_config(config.audio_config)
        self.audio_modal_projector = AeroAudioMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.post_init()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_output_embeddings
    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.set_decoder
    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_decoder
    def get_decoder(self):
        return self.language_model.get_decoder()

    # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.tie_weights
    def tie_weights(self):
        return self.language_model.tie_weights()

    def prepare_inputs_for_qwen_audio_encoder(
        self,
        audio_values: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        audio_feat_lengths: torch.FloatTensor,
        audio_output_lengths: torch.FloatTensor,
    ):
        batch_size, _, max_mel_seq_len = audio_values.shape
        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=audio_feat_lengths.dtype,
                device=audio_feat_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        inputs = {
            "input_features": audio_values,
            "attention_mask": audio_attention_mask,
        }
        return inputs

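To make the mask construction above concrete, here is a self-contained toy version of the same computation with made-up clip lengths; it is illustrative only and not part of the uploaded file.

import torch

# Two made-up clips with different numbers of valid frames.
batch_size, max_mel_seq_len = 2, 3000
audio_feat_lengths = torch.tensor([1500, 900])            # valid frames after the conv front-end
max_seq_len = (max_mel_seq_len - 2) // 2 + 1               # -> 1500, mirroring the downsampling above

seq_range = torch.arange(max_seq_len).unsqueeze(0).expand(batch_size, max_seq_len)
padding_mask = seq_range >= audio_feat_lengths.unsqueeze(1)          # True at padded positions

expanded = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
    batch_size, 1, max_seq_len, max_seq_len
)
additive_mask = torch.zeros(batch_size, 1, max_seq_len, max_seq_len)
additive_mask[expanded] = float("-inf")                    # padded keys are masked out
print(additive_mask.shape)                                 # torch.Size([2, 1, 1500, 1500])
print(additive_mask[1, 0, 0, 899], additive_mask[1, 0, 0, 900])   # tensor(0.) tensor(-inf)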
    def prepare_scattered_audio_values(
        self,
        audio_features,
        audio_output_lengths,
    ):
        # Audio features are in (bs, max_seq_len, hidden_size).
        # If we masked_scatter directly, the embeddings would be placed one by one in an
        # incorrect order, so we remove the padded values first.
        unpadded_audio_features = [
            audio_feat[:audio_output_length]
            for audio_feat, audio_output_length in zip(
                audio_features, audio_output_lengths
            )
        ]
        # Concat the audio features.
        # The result should have exactly audio_mask.sum() values.
        unpadded_audio_features = torch.concatenate(unpadded_audio_features, dim=0)
        return unpadded_audio_features

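A toy illustration of the unpadding step above: only the first audio_output_length rows of each padded feature tensor are kept, so the concatenated result has exactly as many rows as there are audio placeholder tokens. The numbers below are made up.

import torch

hidden_size = 4
audio_features = torch.arange(2 * 5 * hidden_size, dtype=torch.float32).view(2, 5, hidden_size)
audio_output_lengths = torch.tensor([3, 2])    # valid rows per sample

unpadded = torch.concatenate(
    [feat[:length] for feat, length in zip(audio_features, audio_output_lengths)], dim=0
)
print(unpadded.shape)                          # torch.Size([5, 4]) == audio_output_lengths.sum()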
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        audio_values: torch.FloatTensor = None,
        audio_attention_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: int = 0,
    ) -> Union[Tuple, AeroCausalLMOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Embed audio features
        if audio_values is not None:
            (
                audio_feat_lengths,
                audio_output_lengths,
            ) = self.audio_tower._get_feat_extract_output_lengths(
                audio_attention_mask.sum(-1)
            )
            inputs = self.prepare_inputs_for_qwen_audio_encoder(
                audio_values=audio_values,
                audio_attention_mask=audio_attention_mask,
                audio_feat_lengths=audio_feat_lengths,
                audio_output_lengths=audio_output_lengths,
            )

            audio_outputs = self.audio_tower(**inputs)
            selected_audio_feature = audio_outputs.last_hidden_state
            audio_features = self.audio_modal_projector(selected_audio_feature)
            n_audio_tokens = (input_ids == self.config.audio_token_index).sum().item()
            n_audio_features = audio_output_lengths.sum()
            if n_audio_tokens != n_audio_features:
                raise ValueError(
                    f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
                )
            audio_mask = (
                (input_ids == self.config.audio_token_index)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            audio_features = audio_features.to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            audio_features = self.prepare_scattered_audio_values(
                audio_features, audio_output_lengths
            )
            inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            labels=labels,
        )

        # `labels` was forwarded to the language model, so its output may already contain
        # the computed loss; when it does, the loss comes before the logits in the positional
        # layout, so pull both out explicitly instead of assuming outputs[0] is the logits.
        if return_dict:
            logits = outputs.logits
            loss = outputs.loss
        else:
            loss = outputs[0] if labels is not None else None
            logits = outputs[1] if labels is not None else outputs[0]
        if labels is not None and loss is None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(
                    logits.device
                )
                shift_logits = logits[..., :-1, :][
                    shift_attention_mask.to(logits.device) != 0
                ].contiguous()
                shift_labels = labels[..., 1:][
                    shift_attention_mask.to(labels.device) != 0
                ].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1).to(shift_logits.device),
            )

        if not return_dict:
            # Skip the loss/logits entries already consumed above.
            output = (logits,) + (outputs[2:] if labels is not None else outputs[1:])
            return (loss,) + output if loss is not None else output

        return AeroCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            audio_hidden_states=audio_features if audio_values is not None else None,
        )

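The fallback loss above shifts logits and labels by one position and drops padded positions via the attention mask; here is a self-contained toy version of that shift with made-up values (illustrative only).

import torch
from torch import nn

batch, seq_len, vocab = 1, 5, 11
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.tensor([[1, 1, 1, 1, 0]])            # last position is padding

shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1):]
shift_logits = logits[..., :-1, :][shift_attention_mask != 0]   # keep non-padded positions
shift_labels = labels[..., 1:][shift_attention_mask != 0]       # targets are the next tokens
loss = nn.CrossEntropyLoss()(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss)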
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        attention_mask=None,
        cache_position=None,
        logits_to_keep=None,
        audio_values=None,
        audio_attention_mask=None,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward audio inputs to the model

        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        # Audio inputs are only needed on the prefill step (cache_position starts at 0);
        # afterwards the audio embeddings already live in the KV cache.
        if cache_position[0] == 0:
            model_inputs["audio_values"] = audio_values
            model_inputs["audio_attention_mask"] = audio_attention_mask

        return model_inputs
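
Finally, a toy sketch of the gating above: audio inputs are attached only on the prefill step (where cache_position starts at 0), while later decode steps rely on the KV cache. The function and tensor shapes below are illustrative stand-ins, not part of the uploaded file.

import torch

def attach_audio(model_inputs, cache_position, audio_values, audio_attention_mask):
    # Mirrors the check above: only the first generation step carries audio inputs.
    if cache_position[0] == 0:
        model_inputs["audio_values"] = audio_values
        model_inputs["audio_attention_mask"] = audio_attention_mask
    return model_inputs

audio_values = torch.zeros(1, 128, 3000)       # (batch, mel_bins, frames), made-up shape
audio_attention_mask = torch.ones(1, 3000)

prefill = attach_audio({}, torch.tensor([0]), audio_values, audio_attention_mask)
decode = attach_audio({}, torch.tensor([42]), audio_values, audio_attention_mask)
print("audio_values" in prefill, "audio_values" in decode)   # True False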