Spaces:
Running
Running
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Based on:
# https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple | |
import torch | |
from .flash_attention_utils import flash_attention_forward | |
try: | |
from transformers.models.qwen2_vl.modeling_qwen2_vl import ( | |
Qwen2VLAttention, | |
apply_multimodal_rotary_pos_emb, | |
repeat_kv, | |
) | |
from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessor | |
except ImportError: | |
pass | |
def get_rope_index(
    processor: "Qwen2VLProcessor",
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    *,
    tokens_per_second: float = 2,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL, it should be generated before sharding the sequence.
    The batch dim has been removed and the input_ids should be a 1D tensor representing a single example.
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546

    Args:
        processor: Supplies ``image_processor.merge_size`` and the tokenizer used to look up the
            ``<|image_pad|>``, ``<|video_pad|>`` and ``<|vision_start|>`` token ids.
        input_ids: 1D tensor of token ids, shape ``(seqlen,)``.
        image_grid_thw: Per-image ``(t, h, w)`` patch grid, consumed in the order images appear.
        video_grid_thw: Per-video ``(t, h, w)`` patch grid, consumed in the order videos appear.
        second_per_grid_ts: Per-video seconds per temporal grid step; ``1.0`` is used when omitted.
        attention_mask: 1D mask where ``1`` marks real tokens; treated as all-ones when omitted.
        tokens_per_second: Keyword-only scale applied to video temporal positions
            (temporal index = frame_idx * second_per_grid_t * tokens_per_second). Defaults to 2.

    Returns:
        Tensor of shape ``(3, seqlen)`` holding the (temporal, height, width) position of every
        token. Text tokens carry the same value in all three rows; masked-out positions are 1.
    """
    spatial_merge_size = processor.image_processor.merge_size
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Masked-out slots keep the placeholder value 1; real tokens are overwritten at the end.
        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)  # (3, seqlen)
        image_index, video_index = 0, 0
        input_ids = input_ids[attention_mask == 1]

        # Every vision span is introduced by <|vision_start|>; the token right after it tells
        # whether the span is an image or a video. A <|vision_start|> in the very last slot
        # (e.g. after truncation) has no successor, so it is dropped instead of indexing past the end.
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(-1)
        vision_start_indices = vision_start_indices[vision_start_indices + 1 < input_ids.size(0)]
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = int((vision_tokens == image_token_id).sum())
        video_nums = int((vision_tokens == video_token_id).sum())

        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Find the next image / video placeholder at or after `st`; `len + 1` is a sentinel
            # meaning "no more spans of this kind".
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                second_per_grid_t = 0  # images have no temporal extent
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                if second_per_grid_ts is not None:
                    second_per_grid_t = second_per_grid_ts[video_index]
                else:
                    second_per_grid_t = 1.0
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            # The LLM sees one token per merge_size x merge_size block of patches.
            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )

            # Text before this vision span: identical positions on all three axes.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # Vision span: separate temporal / height / width coordinates, offset past the text.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last vision span.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        # The list is empty only when no token survives the mask; there is nothing to write then
        # (torch.cat would raise on an empty list).
        if llm_pos_ids_list:
            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            # size(-1) is the sequence length for the documented 1D input (and the last dim for 2D).
            position_ids = torch.arange(input_ids.size(-1), device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
def qwen2_vl_attn_forward(
    self: "Qwen2VLAttention",
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    """Replacement forward for ``Qwen2VLAttention`` that routes through ``flash_attention_forward``.

    Projects ``hidden_states`` to query/key/value heads, applies multimodal rotary embeddings
    (split by ``self.rope_scaling["mrope_section"]``), repeats key/value heads for grouped-query
    attention, runs flash attention (optionally with a sliding window), and projects the result
    back with ``o_proj``.

    Args:
        hidden_states: ``(batch, q_len, hidden)`` input; under sequence parallelism ``q_len`` is
            the local shard length (per the original in-code notes).
        attention_mask: Forwarded unchanged to ``flash_attention_forward``.
        position_ids: Used to build rotary embeddings when ``position_embeddings`` is absent, and
            always forwarded to ``flash_attention_forward``.
        position_embeddings: Precomputed ``(cos, sin)`` pair; takes precedence over ``rotary_emb``.
        **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        ``(attn_output, None, None)`` with ``attn_output`` shaped ``(batch, q_len, hidden_size)``.
    """
    batch_size, local_len, _ = hidden_states.size()  # local_len = seq_length / sp_size

    def _to_heads(projected: torch.Tensor, head_count: int) -> torch.Tensor:
        # (batch, local_len, head_count * head_dim) -> (batch, head_count, local_len, head_dim)
        return projected.view(batch_size, local_len, head_count, self.head_dim).transpose(1, 2)

    query_states = _to_heads(self.q_proj(hidden_states), self.num_heads)
    key_states = _to_heads(self.k_proj(hidden_states), self.num_key_value_heads)
    value_states = _to_heads(self.v_proj(hidden_states), self.num_key_value_heads)

    # Prefer caller-supplied (cos, sin); otherwise derive them from position_ids.
    cos, sin = (
        position_embeddings
        if position_embeddings is not None
        else self.rotary_emb(value_states, position_ids)
    )
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )

    # Expand KV heads so each query-head group has a matching key/value head.
    group_count = self.num_key_value_groups
    key_states, value_states = repeat_kv(key_states, group_count), repeat_kv(value_states, group_count)

    dropout_rate = self.attention_dropout if self.training else 0.0

    cfg = self.config
    window_enabled = (
        cfg.use_sliding_window
        and getattr(cfg, "sliding_window", None) is not None
        and self.layer_idx >= cfg.max_window_layers
    )
    sliding_window = cfg.sliding_window if window_enabled else None

    # position_ids is passed through deliberately (the original marks it "important");
    # presumably flash_attention_forward uses it to handle packed sequences — confirm there.
    attn_output, _ = flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        position_ids=position_ids,
    )
    # NOTE(review): the reshape assumes the output holds local_len tokens with all heads per rank;
    # confirm against flash_attention_utils when sequence parallelism is enabled.
    merged = attn_output.reshape(batch_size, local_len, self.hidden_size).contiguous()
    return self.o_proj(merged), None, None