diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..41d2271826c4932cc927e5e299c8f4146eda704a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/examples0.jpg filter=lfs diff=lfs merge=lfs -text
+examples/examples1.jpg filter=lfs diff=lfs merge=lfs -text
+examples/examples2.png filter=lfs diff=lfs merge=lfs -text
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7cc3557d13dcb8bbcba160be2af18942e12d163
--- /dev/null
+++ b/app.py
@@ -0,0 +1,195 @@
+import os
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+import torchvision.transforms as T
+from transformers import AutoTokenizer
+import gradio as gr
+from resnet50 import build_model
+from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50
+from utils import IMAGENET_MEAN, IMAGENET_STD
+from internvl.train.dataset import dynamic_preprocess
+from internvl.model.internvl_chat import InternVLChatModel
+
+# Model checkpoint configuration
+CHECKPOINTS = {
+ "TokenOCR-4096-English-seg": "/path/to/TokenOCR_4096_English_seg",
+ "TokenOCR-2048-Bilingual-seg": "/path/to/TokenOCR_2048_Binlinual_seg",
+ "R50":"model/checkpoint.pth",
+ "R50_siglip": "/path/to/R50_siglip_checkpoint.pth"
+}
+
+# Global state shared across Gradio callbacks
+current_vis = []
+current_bpe = []
+current_index = 0
+
+def load_model(check_type):
+ device = torch.device("cpu")
+
+ if check_type == 'R50':
+ tokenizer = load_tokenizer('tokenizer_path')
+ model = build_model(argparse.Namespace()).eval()
+ model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model'])
+ transform = build_transform_R50(normalize_type='imagenet')
+
+ elif check_type == 'R50_siglip':
+ tokenizer = load_tokenizer('tokenizer_path')
+ model = build_model(argparse.Namespace()).eval()
+ model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model'])
+ transform = build_transform_R50(normalize_type='imagenet')
+
+ elif 'TokenOCR' in check_type:
+ model_path = CHECKPOINTS[check_type]
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
+ model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval()
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB')),
+ T.Resize((224, 224)),
+ T.ToTensor(),
+ T.Normalize(IMAGENET_MEAN, IMAGENET_STD)
+ ])
+
+ return model.to(device), tokenizer, transform, device
+
+def process_image(model, tokenizer, transform, device, check_type, image, text):
+ global current_vis, current_bpe
+ src_size = image.size
+ if 'TokenOCR' in check_type:
+ images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
+ image_size=model.config.force_image_size,
+ use_thumbnail=model.config.use_thumbnail,
+ return_ratio=True)
+ pixel_values = torch.stack([transform(img) for img in images]).to(device)
+ else:
+ pixel_values = torch.stack([transform(image)]).to(device)
+ target_ratio = (1, 1)
+
+    # Text processing
+ text += ' '
+ input_ids = tokenizer(text)['input_ids'][1:]
+ input_ids = torch.tensor(input_ids, device=device)
+
+    # Get text and vision embeddings
+ with torch.no_grad():
+ if 'R50' in check_type:
+ text_embeds = model.language_embedding(input_ids)
+ else:
+ text_embeds = model.tok_embeddings(input_ids)
+
+ vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(device))
+ vit_embeds, size2 = post_process(vit_embeds, target_ratio, check_type)
+
+    # Compute text-to-patch similarity
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True)
+ similarity = text_embeds @ vit_embeds.T
+ resized_size = size1 if size1 is not None else size2
+
+ # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192
+ # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944
+ # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912
+
+
+    # Generate the visualization
+ attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1])
+ # attn_map = similarity.reshape(len(text_embeds), *target_ratio)
+ all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids]
+ current_vis = generate_similiarity_map([image], attn_map,
+ [tokenizer.decode([i]) for i in input_ids],
+ [], target_ratio, src_size)
+
+ current_bpe = [tokenizer.decode([i]) for i in input_ids]
+ # current_bpe[-1] = 'Input text'
+ current_bpe[-1] = text
+ return image, current_vis[0], current_bpe[0]
+
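+# A minimal sketch (an assumption, not the actual utils.generate_similiarity_map code) of the
+# "relative value" normalization mentioned in the UI note below: each token's map is rescaled
+# by its own min/max, so even very low absolute responses fill the full color range, which is
+# why text that does not appear in the image produces a noisy-looking heatmap.
+def _relative_normalize_sketch(attn_map: torch.Tensor) -> torch.Tensor:
+    flat = attn_map.reshape(attn_map.shape[0], -1)
+    amin = flat.min(dim=1, keepdim=True).values
+    amax = flat.max(dim=1, keepdim=True).values
+    return ((flat - amin) / (amax - amin + 1e-6)).reshape(attn_map.shape)
+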
+# Event handler functions
+def update_index(change):
+ global current_index
+ current_index = max(0, min(len(current_vis) - 1, current_index + change))
+ return current_vis[current_index], format_bpe_display(current_bpe[current_index])
+
+def format_bpe_display(bpe):
+    # Use HTML to set the font size and color, make the text bold, and center it
+    return f"<div style='font-size: 20px; text-align: center; font-weight: bold;'>Current BPE: <span style='color: red;'>{bpe}</span></div>"
+
+# Gradio interface
+with gr.Blocks(title="BPE Visualization Demo") as demo:
+ gr.Markdown("## BPE Visualization Demo - TokenOCR基座模型能力可视化")
+
+ with gr.Row():
+ with gr.Column(scale=0.5):
+ model_type = gr.Dropdown(
+ choices=["TokenOCR-4096-English-seg", "TokenOCR-2048-Bilingual-seg", "R50", "R50_siglip"],
+ label="Select model type",
+ value="R50" # 设置默认值为第一个选项
+ )
+ image_input = gr.Image(label="Upload images", type="pil")
+ text_input = gr.Textbox(label="Input text")
+
+ run_btn = gr.Button("RUN")
+
+ gr.Examples(
+ examples=[
+ [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"],
+ [os.path.join("examples", "examples1.jpg"), "Refreshers"],
+ [os.path.join("examples", "examples2.png"), "Vision Transformer"]
+ ],
+ inputs=[image_input, text_input],
+ label="Sample input"
+ )
+
+ with gr.Column(scale=2):
+ gr.Markdown("If the input text is not included in the image, the attention map will show a lot of noise (the actual response value is very low), since we normalize the attention map according to the relative value.
")
+
+ with gr.Row():
+ orig_img = gr.Image(label="Original picture", interactive=False)
+ heatmap = gr.Image(label="BPE visualization", interactive=False)
+
+ with gr.Row() as controls:
+            prev_btn = gr.Button("⬅ Previous", visible=False)
+ index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False)
+ next_btn = gr.Button("⮕ Next", visible=False)
+
+ bpe_display = gr.Markdown("Current BPE: ", visible=False)
+
+    # Event handling
+ def on_run_clicked(model_type, image, text):
+ global current_vis, current_bpe, current_index
+ current_index = 0 # Reset index when new image is processed
+ image, vis, bpe = process_image(*load_model(model_type), model_type, image, text)
+ # Update the slider range and set value to 0
+ slider_max_val = len(current_bpe) - 1
+ bpe_text = format_bpe_display(bpe)
+ return image, vis, bpe_text, slider_max_val
+
+ run_btn.click(
+ on_run_clicked,
+ inputs=[model_type, image_input, text_input],
+ outputs=[orig_img, heatmap, bpe_display, index_slider],
+ ).then(
+ lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)),
+ inputs=index_slider,
+ outputs=[prev_btn, index_slider, next_btn, bpe_display],
+ )
+
+ prev_btn.click(
+ lambda: (*update_index(-1), current_index),
+ outputs=[heatmap, bpe_display, index_slider]
+ )
+
+ next_btn.click(
+ lambda: (*update_index(1), current_index),
+ outputs=[heatmap, bpe_display, index_slider]
+ )
+
+ index_slider.change(
+ lambda x: (current_vis[x], format_bpe_display(current_bpe[x])),
+ inputs=index_slider,
+ outputs=[heatmap, bpe_display]
+ )
+
+if __name__ == "__main__":
+ demo.launch()
\ No newline at end of file
diff --git a/examples/examples0.jpg b/examples/examples0.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..89ce15bfdc800ec5851f17369f7e1bbfc8e2f8a6
--- /dev/null
+++ b/examples/examples0.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:afdea4f2f6523d2a8f9a10ed010ea57e257908fa284485b224c276310f723121
+size 423018
diff --git a/examples/examples1.jpg b/examples/examples1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fdd23474194746ea432b850ecd53e6f291f0818e
--- /dev/null
+++ b/examples/examples1.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:347fe59f49c31c93f78a913c891eaeeeace59851f69dfb2c2631dfc7dd6c3886
+size 571030
diff --git a/examples/examples2.png b/examples/examples2.png
new file mode 100644
index 0000000000000000000000000000000000000000..7fb40590b412ee234aa0298c27e7c5bfa8a13fd1
--- /dev/null
+++ b/examples/examples2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8066fab0957545e840369338b2693136f60cecf61a6eb85a9c650ad1e3f69e27
+size 1470171
diff --git a/internvl/__pycache__/conversation.cpython-310.pyc b/internvl/__pycache__/conversation.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bfbde9fd800528506c3b2af19379db5842b6a883
Binary files /dev/null and b/internvl/__pycache__/conversation.cpython-310.pyc differ
diff --git a/internvl/__pycache__/conversation.cpython-39.pyc b/internvl/__pycache__/conversation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de0bcb38f21f0c673cbe4b936c46203f6e941887
Binary files /dev/null and b/internvl/__pycache__/conversation.cpython-39.pyc differ
diff --git a/internvl/__pycache__/dist_utils.cpython-39.pyc b/internvl/__pycache__/dist_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8aef188afd67ddb6ecd1cf8717269190b08a8d1f
Binary files /dev/null and b/internvl/__pycache__/dist_utils.cpython-39.pyc differ
diff --git a/internvl/conversation.py b/internvl/conversation.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf9e80bebbd5342f313fbef61c2b21194157a0db
--- /dev/null
+++ b/internvl/conversation.py
@@ -0,0 +1,402 @@
+"""
+Conversation prompt templates.
+
+We kindly request that you import fastchat instead of copying this file if you wish to use it.
+If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
+"""
+
+import dataclasses
+from enum import IntEnum, auto
+from typing import Any, Dict, List, Tuple, Union
+
+
+class SeparatorStyle(IntEnum):
+ """Separator styles."""
+
+ ADD_COLON_SINGLE = auto()
+ ADD_COLON_TWO = auto()
+ ADD_COLON_SPACE_SINGLE = auto()
+ NO_COLON_SINGLE = auto()
+ NO_COLON_TWO = auto()
+ ADD_NEW_LINE_SINGLE = auto()
+ LLAMA2 = auto()
+ CHATGLM = auto()
+ CHATML = auto()
+ CHATINTERN = auto()
+ DOLLY = auto()
+ RWKV = auto()
+ PHOENIX = auto()
+ ROBIN = auto()
+ FALCON_CHAT = auto()
+ CHATGLM3 = auto()
+ INTERNVL_ZH = auto()
+ MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that manages prompt templates and keeps all conversation history."""
+
+ # The name of this template
+ name: str
+ # The template of the system prompt
+ system_template: str = '{system_message}'
+ # The system message
+ system_message: str = ''
+ # The names of two roles
+ roles: Tuple[str] = ('USER', 'ASSISTANT')
+ # All messages. Each item is (role, message).
+ messages: List[List[str]] = ()
+ # The number of few shot examples
+ offset: int = 0
+ # The separator style and configurations
+ sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+ sep: str = '\n'
+ sep2: str = None
+ # Stop criteria (the default one is EOS token)
+ stop_str: Union[str, List[str]] = None
+ # Stops generation if meeting any token in this list
+ stop_token_ids: List[int] = None
+
+ def get_prompt(self) -> str:
+ """Get the prompt for generation."""
+ system_prompt = self.system_template.format(system_message=self.system_message)
+ if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ': ' + message + seps[i % 2]
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+                ret += role + ': '  # must end with a space
+ return ret
+ elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+ ret = '' if system_prompt == '' else system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + message + self.sep
+ else:
+ ret += role + '\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+ ret = system_prompt
+ for role, message in self.messages:
+ if message:
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + message + seps[i % 2]
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.RWKV:
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += (
+ role
+ + ': '
+ + message.replace('\r\n', '\n').replace('\n\n', '\n')
+ )
+ ret += '\n\n'
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.LLAMA2:
+ seps = [self.sep, self.sep2]
+ if self.system_message:
+ ret = system_prompt
+ else:
+ ret = '[INST] '
+ for i, (role, message) in enumerate(self.messages):
+ tag = self.roles[i % 2]
+ if message:
+ if i == 0:
+ ret += message + ' '
+ else:
+ ret += tag + ' ' + message + seps[i % 2]
+ else:
+ ret += tag
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATGLM:
+ # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+ # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+ round_add_n = 1 if self.name == 'chatglm2' else 0
+ if system_prompt:
+ ret = system_prompt + self.sep
+ else:
+ ret = ''
+
+ for i, (role, message) in enumerate(self.messages):
+ if i % 2 == 0:
+ ret += f'[Round {i//2 + round_add_n}]{self.sep}'
+
+ if message:
+ ret += f'{role}:{message}{self.sep}'
+ else:
+ ret += f'{role}:'
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATML:
+ ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + message + self.sep + '\n'
+ else:
+ ret += role + '\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATGLM3:
+ ret = ''
+ if self.system_message:
+ ret += system_prompt
+ for role, message in self.messages:
+ if message:
+ ret += role + '\n' + ' ' + message
+ else:
+ ret += role
+ return ret
+ elif self.sep_style == SeparatorStyle.CHATINTERN:
+ # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ # if i % 2 == 0:
+ # ret += ""
+ if message:
+ ret += role + ':' + message + seps[i % 2] + '\n'
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.DOLLY:
+ seps = [self.sep, self.sep2]
+ ret = system_prompt
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ':\n' + message + seps[i % 2]
+ if i % 2 == 1:
+ ret += '\n\n'
+ else:
+ ret += role + ':\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.PHOENIX:
+ ret = system_prompt
+ for role, message in self.messages:
+ if message:
+                ret += role + ': ' + '<s>' + message + '</s>'
+            else:
+                ret += role + ': ' + '<s>'
+ return ret
+ elif self.sep_style == SeparatorStyle.ROBIN:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ':\n' + message + self.sep
+ else:
+ ret += role + ':\n'
+ return ret
+ elif self.sep_style == SeparatorStyle.FALCON_CHAT:
+ ret = ''
+ if self.system_message:
+ ret += system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ ret += role + ': ' + message + self.sep
+ else:
+ ret += role + ':'
+
+ return ret
+ elif self.sep_style == SeparatorStyle.INTERNVL_ZH:
+ seps = [self.sep2, self.sep]
+ ret = self.system_message + seps[0]
+ for i, (role, message) in enumerate(self.messages):
+ if message:
+ ret += role + ': ' + message + seps[i % 2]
+ else:
+ ret += role + ':'
+ return ret
+ elif self.sep_style == SeparatorStyle.MPT:
+ ret = system_prompt + self.sep
+ for role, message in self.messages:
+ if message:
+ if type(message) is tuple:
+ message, _, _ = message
+ ret += role + message + self.sep
+ else:
+ ret += role
+ return ret
+ else:
+ raise ValueError(f'Invalid style: {self.sep_style}')
+
+ def set_system_message(self, system_message: str):
+ """Set the system message."""
+ self.system_message = system_message
+
+ def append_message(self, role: str, message: str):
+ """Append a new message."""
+ self.messages.append([role, message])
+
+ def update_last_message(self, message: str):
+ """Update the last output.
+
+ The last message is typically set to be None when constructing the prompt,
+ so we need to update it in-place after getting the response from a model.
+ """
+ self.messages[-1][1] = message
+
+ def to_gradio_chatbot(self):
+ """Convert the conversation to gradio chatbot format."""
+ ret = []
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append([msg, None])
+ else:
+ ret[-1][-1] = msg
+ return ret
+
+ def to_openai_api_messages(self):
+ """Convert the conversation to OpenAI chat completion format."""
+ ret = [{'role': 'system', 'content': self.system_message}]
+
+ for i, (_, msg) in enumerate(self.messages[self.offset :]):
+ if i % 2 == 0:
+ ret.append({'role': 'user', 'content': msg})
+ else:
+ if msg is not None:
+ ret.append({'role': 'assistant', 'content': msg})
+ return ret
+
+ def copy(self):
+ return Conversation(
+ name=self.name,
+ system_template=self.system_template,
+ system_message=self.system_message,
+ roles=self.roles,
+ messages=[[x, y] for x, y in self.messages],
+ offset=self.offset,
+ sep_style=self.sep_style,
+ sep=self.sep,
+ sep2=self.sep2,
+ stop_str=self.stop_str,
+ stop_token_ids=self.stop_token_ids,
+ )
+
+ def dict(self):
+ return {
+ 'template_name': self.name,
+ 'system_message': self.system_message,
+ 'roles': self.roles,
+ 'messages': self.messages,
+ 'offset': self.offset,
+ }
+
+
+# A global registry for all conversation templates
+conv_templates: Dict[str, Conversation] = {}
+
+
+def register_conv_template(template: Conversation, override: bool = False):
+ """Register a new conversation template."""
+ if not override:
+ assert (
+ template.name not in conv_templates
+ ), f'{template.name} has been registered.'
+
+ conv_templates[template.name] = template
+
+
+def get_conv_template(name: str) -> Conversation:
+ """Get a conversation template."""
+ return conv_templates[name].copy()
+
+
+# InternVL-Chat-V1-1 template
+register_conv_template(
+ Conversation(
+ name='internvl_zh',
+ system_template='',
+        roles=('<human>', '<bot>'),
+        sep_style=SeparatorStyle.INTERNVL_ZH,
+        sep='</s>',
+ sep2=' ',
+ )
+)
+
+
+# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference
+# is that during training, the preprocessing function for the Hermes-2 template doesn't add
+# the `<s>` token at the beginning of the tokenized sequence, while the internlm2-chat template does.
+# Therefore, they are completely equivalent during inference.
+register_conv_template(
+ Conversation(
+ name='Hermes-2',
+ system_template='<|im_start|>system\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|im_end|>',
+ stop_str='<|endoftext|>',
+ )
+)
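+# Usage illustration (derived from the MPT branch of get_prompt above; not in the original file):
+#   conv = get_conv_template('Hermes-2')
+#   conv.append_message(conv.roles[0], 'Hi'); conv.append_message(conv.roles[1], None)
+#   conv.get_prompt() then renders, with the system message filled in, as
+#   '<|im_start|>system\n{system_message}<|im_end|><|im_start|>user\nHi<|im_end|><|im_start|>assistant\n'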
+
+
+register_conv_template(
+ Conversation(
+ name='internlm2-chat',
+ system_template='<|im_start|>system\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|im_end|>',
+ )
+)
+
+
+register_conv_template(
+ Conversation(
+ name='phi3-chat',
+ system_template='<|system|>\n{system_message}',
+ # note: The new system prompt was not used here to avoid changes in benchmark performance.
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。',
+ roles=('<|user|>\n', '<|assistant|>\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|end|>',
+ )
+)
+
+
+register_conv_template(
+ Conversation(
+ name='internvl2_5',
+ system_template='<|im_start|>system\n{system_message}',
+ system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
+ roles=('<|im_start|>user\n', '<|im_start|>assistant\n'),
+ sep_style=SeparatorStyle.MPT,
+ sep='<|im_end|>\n',
+ )
+)
diff --git a/internvl/dist_utils.py b/internvl/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb8ae27731968e1acc11134b6c204b6d3c39afa
--- /dev/null
+++ b/internvl/dist_utils.py
@@ -0,0 +1,104 @@
+import os
+import socket
+import subprocess
+from datetime import timedelta
+
+import deepspeed
+import torch
+import torch.multiprocessing as mp
+from torch import distributed as dist
+
+timeout = timedelta(minutes=60)
+
+
+def _find_free_port():
+ # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(('', 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def _is_free_port(port):
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+ ips.append('localhost')
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'mpi':
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ # TODO: use local_rank instead of rank % num_gpus
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ # dist.init_process_group(backend=backend, **kwargs)
+ deepspeed.init_distributed(dist_backend=backend)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ torch.cuda.set_device(local_rank)
+ if 'MASTER_PORT' not in os.environ:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ if 'MASTER_ADDR' not in os.environ:
+ raise KeyError('The environment variable MASTER_ADDR is not set')
+ os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
+ os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # if torch.distributed default port(29500) is available
+ # then use it, else find a free port
+ if _is_free_port(29500):
+ os.environ['MASTER_PORT'] = '29500'
+ else:
+ os.environ['MASTER_PORT'] = str(_find_free_port())
+ # use MASTER_ADDR in the environment variable if it already exists
+ if 'MASTER_ADDR' not in os.environ:
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ # dist.init_process_group(backend=backend, timeout=timeout)
+ deepspeed.init_distributed(dist_backend=backend)
diff --git a/internvl/model/__init__.py b/internvl/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c006f6d39c59b0d98ad820994ba15b09bf96539b
--- /dev/null
+++ b/internvl/model/__init__.py
@@ -0,0 +1,66 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import math
+
+import torch
+from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel
+from transformers import AutoTokenizer
+
+
+def split_model(num_layers, vit_alpha=0.5):
+ device_map = {}
+ world_size = torch.cuda.device_count()
+ # Since the first GPU will be used for ViT, treat it as half a GPU.
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
+ num_layers_per_gpu = [num_layers_per_gpu] * world_size
+ num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha))
+ layer_cnt = 0
+ for i, num_layer in enumerate(num_layers_per_gpu):
+ for j in range(num_layer):
+ device_map[f'language_model.model.layers.{layer_cnt}'] = i
+ layer_cnt += 1
+ device_map['vision_model'] = 0
+ device_map['mlp1'] = 0
+ device_map['language_model.model.tok_embeddings'] = 0
+ device_map['language_model.model.embed_tokens'] = 0
+ device_map['language_model.output'] = 0
+ device_map['language_model.model.norm'] = 0
+ device_map['language_model.lm_head'] = 0
+ device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
+
+ return device_map
+
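+# Illustration (assumes 2 visible GPUs; derived from the arithmetic above, not measured):
+# split_model(32) gives ceil(32 / 1.5) = 22 layers per full GPU, halved to 11 for GPU 0,
+# so decoder layers 0-10 go to GPU 0 and 11-31 to GPU 1, while the vision tower, token
+# embeddings, final norm, lm_head/output and the last decoder layer are pinned to GPU 0.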
+
+def load_model_and_tokenizer(args):
+ if args.auto:
+ config = InternVLChatConfig.from_pretrained(args.checkpoint)
+ num_hidden_layers = config.llm_config.num_hidden_layers
+ device_map = split_model(num_hidden_layers)
+ kwargs = {'device_map': device_map} if args.auto else {}
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
+ model = InternVLChatModel.from_pretrained(
+ args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
+ load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
+ if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
+ model = model.cuda()
+ return model, tokenizer
+
+def load_model_and_tokenizer_customed(args):
+ if args.auto:
+ config = InternVLChatConfig.from_pretrained(args.checkpoint)
+ num_hidden_layers = config.llm_config.num_hidden_layers
+ device_map = split_model(num_hidden_layers)
+ kwargs = {'device_map': device_map} if args.auto else {}
+ tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
+ model = InternVLChatModel.from_pretrained(
+ args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
+ load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
+ if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
+ del model.language_model.model.layers
+ del model.language_model.output
+ return model, tokenizer
+
diff --git a/internvl/model/__pycache__/__init__.cpython-310.pyc b/internvl/model/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14986ed766a1750e26dd590cb173849b754a6b2f
Binary files /dev/null and b/internvl/model/__pycache__/__init__.cpython-310.pyc differ
diff --git a/internvl/model/__pycache__/__init__.cpython-39.pyc b/internvl/model/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4ffe980293aabb7ab792d4ad980f774f0771484
Binary files /dev/null and b/internvl/model/__pycache__/__init__.cpython-39.pyc differ
diff --git a/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0412df00467d99bdd04fa62d05119a3205298158
Binary files /dev/null and b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc differ
diff --git a/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4b5eace8fe3afa226523803b99c244bdb24ca30c
Binary files /dev/null and b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc differ
diff --git a/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..429e50e05cfa3d5b1206e6070808520cef3fa7e0
Binary files /dev/null and b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc differ
diff --git a/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5a5e246572a6541734175afe216efe2f193aee3
Binary files /dev/null and b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc differ
diff --git a/internvl/model/internlm2/configuration_internlm2.py b/internvl/model/internlm2/configuration_internlm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..282b13b1e2066ecc074ecae87b35a19d251f0ed7
--- /dev/null
+++ b/internvl/model/internlm2/configuration_internlm2.py
@@ -0,0 +1,150 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/configuration_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" InternLM2 model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+# Modified from transformers.model.llama.configuration_llama.LlamaConfig
+class InternLM2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
+ an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
+ configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`InternLM2Model`]
+ hidden_size (`int`, *optional*, defaults to 4096):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 11008):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
+ Whether to tie weight embeddings
+ Example:
+
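+        ```python
+        >>> from internvl.model.internlm2.configuration_internlm2 import InternLM2Config
+        >>> # Minimal sketch: build a default config and inspect a couple of fields
+        >>> # (values taken from the __init__ defaults below).
+        >>> configuration = InternLM2Config()
+        >>> configuration.hidden_size, configuration.num_hidden_layers
+        (4096, 32)
+        ```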
+ """
+ model_type = 'internlm2'
+ _auto_class = 'AutoConfig'
+
+ def __init__( # pylint: disable=W0102
+ self,
+ vocab_size=103168,
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ hidden_act='silu',
+ max_position_embeddings=2048,
+ initializer_range=0.02,
+ rms_norm_eps=1e-6,
+ use_cache=True,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ tie_word_embeddings=False,
+ bias=True,
+ rope_theta=10000,
+ rope_scaling=None,
+ attn_implementation='eager',
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.bias = bias
+
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+
+ self.hidden_act = hidden_act
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+
+ self.attn_implementation = attn_implementation
+ if self.attn_implementation is None:
+ self.attn_implementation = 'eager'
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+ raise ValueError(
+                '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
+ f'got {self.rope_scaling}'
+ )
+ rope_scaling_type = self.rope_scaling.get('type', None)
+ rope_scaling_factor = self.rope_scaling.get('factor', None)
+ if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
+ raise ValueError(
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+ )
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
+ raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
diff --git a/internvl/model/internlm2/modeling_internlm2.py b/internvl/model/internlm2/modeling_internlm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..569513dffad6a2bce63c26d7463f90fbcc289b2c
--- /dev/null
+++ b/internvl/model/internlm2/modeling_internlm2.py
@@ -0,0 +1,1429 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/modeling_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch InternLM2 model."""
+import math
+import queue
+import threading
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (add_start_docstrings,
+ add_start_docstrings_to_model_forward, logging,
+ replace_return_docstrings)
+
+try:
+ from transformers.generation.streamers import BaseStreamer
+except: # noqa # pylint: disable=bare-except
+ BaseStreamer = None
+
+from .configuration_internlm2 import InternLM2Config
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = 'InternLM2Config'
+
+flash_attn_func, flash_attn_varlen_func = None, None
+pad_input, index_first_axis, unpad_input = None, None, None
+try:
+ from flash_attn import flash_attn_func as _flash_attn_func
+ from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis as _index_first_axis
+ from flash_attn.bert_padding import pad_input as _pad_input
+ from flash_attn.bert_padding import unpad_input as _unpad_input
+
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+ has_flash_attn = True
+except:
+ has_flash_attn = False
+
+
+def _import_flash_attn():
+ global flash_attn_func, flash_attn_varlen_func
+ global pad_input, index_first_axis, unpad_input
+ try:
+ from flash_attn import flash_attn_func as _flash_attn_func
+ from flash_attn import \
+ flash_attn_varlen_func as _flash_attn_varlen_func
+ from flash_attn.bert_padding import \
+ index_first_axis as _index_first_axis
+ from flash_attn.bert_padding import pad_input as _pad_input
+ from flash_attn.bert_padding import unpad_input as _unpad_input
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+ except ImportError:
+ raise ImportError('flash_attn is not installed.')
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
+class InternLM2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ InternLM2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+ from functools import partial
+
+ from apex.normalization import FusedRMSNorm
+ InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternLM2RMSNorm')
+except ImportError:
+ # using the normal LlamaRMSNorm
+ pass
+except Exception:
+ print('discovered apex but it failed to load, falling back to InternLM2RMSNorm')
+ pass
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
+class InternLM2RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
+class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+ self.scaling_factor = scaling_factor
+ super().__init__(dim, max_position_embeddings, base, device)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+ t = t / self.scaling_factor
+
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
+class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
+ """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
+ Credits to the Reddit users /u/bloc97 and /u/emozilla.
+ """
+
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+ self.scaling_factor = scaling_factor
+ super().__init__(dim, max_position_embeddings, base, device)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+
+ if seq_len > self.max_position_embeddings:
+ base = self.base * (
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+ ) ** (self.dim / (self.dim - 2))
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+ self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+ t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+ self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
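+# Worked example (assumed values, not from the original file): with base=10000, dim=128,
+# max_position_embeddings=2048, scaling_factor=1.0 and seq_len=4096, the rescaled base is
+# 10000 * ((1.0 * 4096 / 2048) - 0.0) ** (128 / 126) ≈ 20,200, i.e. the RoPE wavelengths
+# are stretched so the longer context still fits within the rotation period.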
+
+# Copied from transformers.model.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors."""
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class InternLM2MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x))
+
+ return down_proj
+
+
+# Copied from transformers.model.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
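+# Shape illustration (assumed numbers): with 8 KV heads and n_rep=4 (i.e. 32 query heads),
+# a (2, 8, 128, 64) tensor becomes (2, 32, 128, 64), matching torch.repeat_interleave(x, dim=1, repeats=4).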
+
+# Modified from transformers.model.llama.modeling_llama.LlamaAttention
+class InternLM2Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: InternLM2Config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
+ f' and `num_heads`: {self.num_heads}).'
+ )
+
+ self.wqkv = nn.Linear(
+ self.hidden_size,
+ (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
+ bias=config.bias,
+ )
+
+ self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.config.rope_scaling is None:
+ self.rotary_emb = InternLM2RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.config.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling['type']
+ scaling_factor = self.config.rope_scaling['factor']
+ if scaling_type == 'dynamic':
+ self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.config.rope_theta,
+ scaling_factor=scaling_factor,
+ )
+ elif scaling_type == 'linear':
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.config.rope_theta,
+ scaling_factor=scaling_factor,
+ )
+ else:
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
+ return self.rotary_emb
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
+                'Please make sure to use `attention_mask` instead.'
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
+ f' {attn_weights.size()}'
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}'
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
+class InternLM2FlashAttention2(InternLM2Attention):
+ """
+    InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # InternLM2FlashAttention2 attention does not support output_attentions
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
+                'Please make sure to use `attention_mask` instead.'
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.wqkv(hidden_states)
+
+ qkv_states = rearrange(
+ qkv_states,
+ 'b q (h gs d) -> b q h gs d',
+ gs=2 + self.num_key_value_groups,
+ d=self.head_dim,
+ )
+
+ query_states = qkv_states[..., : self.num_key_value_groups, :]
+ query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+ key_states = qkv_states[..., -2, :]
+ value_states = qkv_states[..., -1, :]
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states, key_states, value_states, attention_mask, q_len
+ )
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.wo(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
+ the input is first unpadded, the attention scores are computed, and the output is padded back to the original shape.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`, *optional*):
+ Attention dropout probability.
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
+ """
+ # Contains at least one padding token in the sequence
+ causal = self.is_causal and query_length != 1
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
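+ # cu_seqlens_* are cumulative sequence lengths in int32, e.g. per-sample lengths [3, 5]
+ # become tensor([0, 3, 8]); flash_attn_varlen_func uses them to address the unpadded
+ # (total_tokens, num_heads, head_dim) tensors, so no batch dimension is needed.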
+
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+ )
+
+ return attn_output
+
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q.to(torch.int64),
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+INTERNLM2_ATTENTION_CLASSES = {
+ 'eager': InternLM2Attention,
+ 'flash_attention_2': InternLM2FlashAttention2,
+}
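+# The decoder layer below selects its attention implementation from this mapping via
+# `config.attn_implementation`; InternLM2Model falls back to 'eager' when flash-attn is
+# not importable (see the `has_flash_attn` check in InternLM2Model.__init__).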
+
+
+# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
+class InternLM2DecoderLayer(nn.Module):
+ def __init__(self, config: InternLM2Config):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
+
+ self.feed_forward = InternLM2MLP(config)
+ self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. '
+ 'Please make sure to use `attention_mask` instead.'
+ )
+
+ residual = hidden_states
+
+ hidden_states = self.attention_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.attention(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.ffn_norm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+InternLM2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`InternLM2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2PreTrainedModel(PreTrainedModel):
+ config_class = InternLM2Config
+ base_model_prefix = 'model'
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['InternLM2DecoderLayer']
+ _skip_keys_device_placement = 'past_key_values'
+ _supports_flash_attn_2 = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+InternLM2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
+ when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+# Modified from transformers.model.llama.modeling_llama.LlamaModel
+@add_start_docstrings(
+ 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.',
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2Model(InternLM2PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLM2DecoderLayer`]
+
+ Args:
+ config: InternLM2Config
+ """
+
+ _auto_class = 'AutoModel'
+
+ def __init__(self, config: InternLM2Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+ self.config = config
+ if not has_flash_attn:
+ self.config.attn_implementation = 'eager'
+ print('Warning: Flash attention is not available, using eager attention instead.')
+
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+
+ self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.tok_embeddings = value
+
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.attn_implementation == 'flash_attention_2':
+ _import_flash_attn()
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.tok_embeddings(input_ids)
+
+ if self.config.attn_implementation == 'flash_attention_2':
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
+class InternLM2ForCausalLM(InternLM2PreTrainedModel):
+ _auto_class = 'AutoModelForCausalLM'
+
+ _tied_weights_keys = ['output.weight']
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = InternLM2Model(config)
+ self.vocab_size = config.vocab_size
+ self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ def get_output_embeddings(self):
+ return self.output
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, InternLM2ForCausalLM
+
+ >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.output(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ output = CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ output['logits'] = output['logits'].to(device)
+ return output
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+
+ position_ids = kwargs.get('position_ids', None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1]:]
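+ # Example: attention_mask [[0, 1, 1, 1]] (left padding) gives cumsum - 1 = [[-1, 0, 1, 2]];
+ # the padded slot is filled with 1, yielding position_ids [[1, 0, 1, 2]]. The filler value
+ # is irrelevant because that position is masked out in attention anyway.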
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {'inputs_embeds': inputs_embeds}
+ else:
+ model_inputs = {'input_ids': input_ids}
+
+ model_inputs.update(
+ {
+ 'position_ids': position_ids,
+ 'past_key_values': past_key_values,
+ 'use_cache': kwargs.get('use_cache'),
+ 'attention_mask': attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''):
+ if tokenizer.add_bos_token:
+ prompt = ''
+ else:
+ prompt = tokenizer.bos_token
+ if meta_instruction:
+ prompt += f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n"""
+ for record in history:
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
+ return tokenizer([prompt], return_tensors='pt')
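+ # For illustration, build_inputs(tokenizer, 'Hi', history=[], meta_instruction='You are helpful')
+ # produces a ChatML-style prompt roughly like:
+ #   <|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n
+ # prefixed with the BOS token only when the tokenizer does not already add one itself.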
+
+ @torch.no_grad()
+ def chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = [],
+ streamer: Optional[BaseStreamer] = None,
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ meta_instruction: str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
+ '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n'
+ '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.',
+ **kwargs,
+ ):
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
+ inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
+ # also add the end-of-assistant token (<|im_end|>) to the eos token ids to avoid unnecessary generation
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]]
+ outputs = self.generate(
+ **inputs,
+ streamer=streamer,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):]
+ response = tokenizer.decode(outputs, skip_special_tokens=True)
+ response = response.split('<|im_end|>')[0]
+ history = history + [(query, response)]
+ return response, history
+
+ @torch.no_grad()
+ def stream_chat(
+ self,
+ tokenizer,
+ query: str,
+ history: List[Tuple[str, str]] = [],
+ max_new_tokens: int = 1024,
+ do_sample: bool = True,
+ temperature: float = 0.8,
+ top_p: float = 0.8,
+ **kwargs,
+ ):
+ """
+ Return a generator in format: (response, history)
+ Eg.
+ ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
+ ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
+ """
+ if BaseStreamer is None:
+ raise ModuleNotFoundError(
+ 'The version of `transformers` is too low. Please make sure '
+ 'that you have installed `transformers>=4.28.0`.'
+ )
+
+ response_queue = queue.Queue(maxsize=20)
+
+ class ChatStreamer(BaseStreamer):
+ def __init__(self, tokenizer) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.queue = response_queue
+ self.query = query
+ self.history = history
+ self.response = ''
+ self.cache = []
+ self.received_inputs = False
+ self.queue.put((self.response, history + [(self.query, self.response)]))
+
+ def put(self, value):
+ if len(value.shape) > 1 and value.shape[0] > 1:
+ raise ValueError('ChatStreamer only supports batch size 1')
+ elif len(value.shape) > 1:
+ value = value[0]
+
+ if not self.received_inputs:
+ # The first received value is input_ids, ignore here
+ self.received_inputs = True
+ return
+
+ self.cache.extend(value.tolist())
+ token = self.tokenizer.decode(self.cache, skip_special_tokens=True)
+ if token.strip() != '<|im_end|>':
+ self.response = self.response + token
+ history = self.history + [(self.query, self.response)]
+ self.queue.put((self.response, history))
+ self.cache = []
+ else:
+ self.end()
+
+ def end(self):
+ self.queue.put(None)
+
+ def stream_producer():
+ return self.chat(
+ tokenizer=tokenizer,
+ query=query,
+ streamer=ChatStreamer(tokenizer=tokenizer),
+ history=history,
+ max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
+ temperature=temperature,
+ top_p=top_p,
+ **kwargs,
+ )
+
+ def consumer():
+ producer = threading.Thread(target=stream_producer)
+ producer.start()
+ while True:
+ res = response_queue.get()
+ if res is None:
+ return
+ yield res
+
+ return consumer()
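+ # Illustrative usage sketch (not part of the original file):
+ #   for response, history in model.stream_chat(tokenizer, query='Hello'):
+ #       print(response)
+ # Each yielded tuple holds the partial response decoded so far and the updated history.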
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
+@add_start_docstrings(
+ """
+ The InternLM2 Model transformer with a sequence classification head on top (linear layer).
+
+ [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification,
+ as other causal models (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (takes the last value in
+ each row of the batch).
+ """,
+ InternLM2_START_DOCSTRING,
+)
+class InternLM2ForSequenceClassification(InternLM2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = InternLM2Model(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.tok_embeddings
+
+ def set_input_embeddings(self, value):
+ self.model.tok_embeddings = value
+
+ @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
+ logits.device
+ )
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
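+ # With a pad_token_id configured, `sequence_lengths` is (first pad index - 1), i.e. the last
+ # non-padding token per row (wrapping to -1, the final position, for rows without padding);
+ # `pooled_logits` therefore gathers one logits vector per sequence for the loss below.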
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = 'regression'
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = 'single_label_classification'
+ else:
+ self.config.problem_type = 'multi_label_classification'
+
+ if self.config.problem_type == 'regression':
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == 'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
diff --git a/internvl/model/internlm2/tokenization_internlm2.py b/internvl/model/internlm2/tokenization_internlm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be581da37ef678de65f2737493fc0ed7160446e
--- /dev/null
+++ b/internvl/model/internlm2/tokenization_internlm2.py
@@ -0,0 +1,235 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for InternLM."""
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
+
+PRETRAINED_VOCAB_FILES_MAP = {}
+
+
+# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
+class InternLM2Tokenizer(PreTrainedTokenizer):
+ """
+ Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ model_input_names = ['input_ids', 'attention_mask']
+ _auto_class = 'AutoTokenizer'
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token='<unk>',
+ bos_token='<s>',
+ eos_token='</s>',
+ pad_token='</s>',
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ decode_with_prefix_space=False,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.decode_with_prefix_space = decode_with_prefix_space
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+ self._no_prefix_space_tokens = None
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def no_prefix_space_tokens(self):
+ if self._no_prefix_space_tokens is None:
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')}
+ return self._no_prefix_space_tokens
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ @property
+ def bos_token_id(self) -> Optional[int]:
+ return self.sp_model.bos_id()
+
+ @property
+ def eos_token_id(self) -> Optional[int]:
+ return self.sp_model.eos_id()
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text):
+ """Returns a tokenized string."""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def _maybe_add_prefix_space(self, tokens, decoded):
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
+ return ' ' + decoded
+ else:
+ return decoded
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ''
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += ' '
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ out_string = self.clean_up_tokenization(out_string)
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
+ return out_string[1:]
+
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, 'wb') as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is not None:
+ output = output + token_ids_1
+
+ if self.add_eos_token:
+ output = output + [self.eos_token_id]
+
+ return output
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
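+ # e.g. token_ids_0 = [5, 6] with no pair returns [1, 0, 0, 1]: the BOS and EOS slots are
+ # flagged as special, the sequence tokens are not.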
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2 does not
+ make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ eos = [self.eos_token_id]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + eos) * [0]
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
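+ # e.g. token_ids_0 = [5, 6] with no pair returns [0, 0, 0] (two tokens plus EOS, all zeros).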
diff --git a/internvl/model/internlm2/tokenization_internlm2_fast.py b/internvl/model/internlm2/tokenization_internlm2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa0fccbd0f1d029d79e19821f2edcb01b594537c
--- /dev/null
+++ b/internvl/model/internlm2/tokenization_internlm2_fast.py
@@ -0,0 +1,211 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization Fast class for InternLM."""
+import os
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple
+
+from tokenizers import Tokenizer, decoders, normalizers, processors
+from tokenizers.models import BPE
+from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
+ SentencePieceExtractor,
+ SpmConverter)
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import logging
+
+from .tokenization_internlm2 import InternLM2Tokenizer
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
+
+
+# Modified from transformers.convert_slow_tokenizer.LlamaConverter
+class InternLM2Converter(SpmConverter):
+ handle_byte_fallback = True
+
+ def vocab(self, proto):
+ vocab = [
+ ('<unk>', 0.0),
+ ('<s>', 0.0),
+ ('</s>', 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 0
+ return unk_id
+
+ def decoder(self, replacement, add_prefix_space):
+ return decoders.Sequence(
+ [
+ decoders.Replace('▁', ' '),
+ decoders.ByteFallback(),
+ decoders.Fuse(),
+ decoders.Strip(content=' ', left=1),
+ ]
+ )
+
+ def tokenizer(self, proto):
+ model_type = proto.trainer_spec.model_type
+ vocab_scores = self.vocab(proto)
+ # special tokens
+ added_tokens = self.original_tokenizer.added_tokens_decoder
+ for i in range(len(vocab_scores)):
+ piece, score = vocab_scores[i]
+ if i in added_tokens:
+ vocab_scores[i] = (added_tokens[i].content, score)
+ if model_type == 1:
+ raise RuntimeError('InternLM2 is supposed to be a BPE model!')
+
+ elif model_type == 2:
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+ tokenizer = Tokenizer(
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+ )
+ tokenizer.add_special_tokens(
+ [added_token for index, added_token in added_tokens.items()]
+ )
+ else:
+ raise Exception(
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+ )
+
+ return tokenizer
+
+ def normalizer(self, proto):
+ normalizers_list = []
+ if proto.normalizer_spec.add_dummy_prefix:
+ normalizers_list.append(normalizers.Prepend(prepend='▁'))
+ normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
+ return normalizers.Sequence(normalizers_list)
+
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ return None
+
+
+SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
+
+
+# Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
+class InternLM2TokenizerFast(PreTrainedTokenizerFast):
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = InternLM2Tokenizer
+ padding_side = 'left'
+ model_input_names = ['input_ids', 'attention_mask']
+ _auto_class = 'AutoTokenizer'
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token='<unk>',
+ bos_token='<s>',
+ eos_token='</s>',
+ pad_token='</s>',
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ decode_with_prefix_space=False,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ sp_model_kwargs=sp_model_kwargs,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ decode_with_prefix_space=decode_with_prefix_space,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+ self.vocab_file = vocab_file
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError('add_bos_token = True but bos_token = None')
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError('add_eos_token = True but eos_token = None')
+
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
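+ # With the defaults (add_bos_token=True, add_eos_token=False) the single-sequence template
+ # becomes "<s>:0 $A:0", so encoded sequences are prefixed with BOS only.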
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
+ 'tokenizer.'
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/internvl/model/internvl_chat/__init__.py b/internvl/model/internvl_chat/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d57341208843b8ac0376b4f8cd3e3b9186e3621
--- /dev/null
+++ b/internvl/model/internvl_chat/__init__.py
@@ -0,0 +1,13 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+from .configuration_intern_vit import InternVisionConfig
+from .configuration_internvl_chat import InternVLChatConfig
+from .modeling_intern_vit import InternVisionModel
+from .modeling_internvl_chat import InternVLChatModel
+
+__all__ = ['InternVisionConfig', 'InternVisionModel',
+ 'InternVLChatConfig', 'InternVLChatModel']
diff --git a/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e81755f6af990c96b4a8d5e470a3d2cb30c90325
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1a1de5ed144acd5da2a4effba7536a27d1371dc
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49a8532ae09c8d389d82984ebaf67ba8a40c90b8
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..feb5af75a4978ce1222e0ad77cde59d0b08d71a7
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..811ee0d8f78aa3c0c096073de822a954768de549
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71e4f6c5fe4899fd18e5dfeac3200312d305b721
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..82264473109e907b5f0169ff9e194b6c64308c9b
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6e54913321f8bdf9252a241371c10d1841884c96
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e2d4e073dd27e2b99ffa8c69b5b02ba931f14a42
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5f8d763e6f76283c546c72eb0cefb0e156af6b9
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3976792dd9d27d4dde8d4affd3c2413b78b1dc33
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc differ
diff --git a/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e870ef669b242afc8c54969591bbfaecb2267cb9
Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc differ
diff --git a/internvl/model/internvl_chat/configuration_intern_vit.py b/internvl/model/internvl_chat/configuration_intern_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e630c456eb9cf350e55bf850c3ff72f445a7e17
--- /dev/null
+++ b/internvl/model/internvl_chat/configuration_intern_vit.py
@@ -0,0 +1,120 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class InternVisionConfig(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ num_channels (`int`, *optional*, defaults to 3):
+ Number of color channels in the input images (e.g., 3 for RGB).
+ patch_size (`int`, *optional*, defaults to 14):
+ The size (resolution) of each patch.
+ image_size (`int`, *optional*, defaults to 224):
+ The size (resolution) of each image.
+ qkv_bias (`bool`, *optional*, defaults to `False`):
+ Whether to add a bias to the queries and values in the self-attention layers.
+ hidden_size (`int`, *optional*, defaults to 3200):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_attention_heads (`int`, *optional*, defaults to 25):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 12800):
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
+ qk_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to normalize the queries and keys in the self-attention layers.
+ num_hidden_layers (`int`, *optional*, defaults to 48):
+ Number of hidden layers in the Transformer encoder.
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
+ Whether to use flash attention mechanism.
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
+            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+            `"relu"`, `"selu"` and `"gelu_new"` are supported.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
+ The epsilon used by the layer normalization layers.
+ dropout (`float`, *optional*, defaults to 0.0):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
+ Dropout rate for stochastic depth.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ initializer_factor (`float`, *optional*, defaults to 0.1):
+ A factor for layer scale.
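+
+    Example (an illustrative sketch only; the classes are imported from their module paths under
+    `internvl.model.internvl_chat`, and the default values above describe a large vision encoder):
+
+    ```python
+    >>> from internvl.model.internvl_chat.configuration_intern_vit import InternVisionConfig
+    >>> from internvl.model.internvl_chat.modeling_intern_vit import InternVisionModel
+
+    >>> # Initializing a config with the default values documented above
+    >>> configuration = InternVisionConfig()
+
+    >>> # Initializing a vision model (with random weights) from that configuration
+    >>> model = InternVisionModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```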
+ """
+
+ model_type = 'intern_vit_6b'
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=14,
+ image_size=224,
+ qkv_bias=False,
+ hidden_size=3200,
+ num_attention_heads=25,
+ intermediate_size=12800,
+ qk_normalization=True,
+ num_hidden_layers=48,
+ use_flash_attn=True,
+ hidden_act='gelu',
+ norm_type='rms_norm',
+ layer_norm_eps=1e-6,
+ dropout=0.0,
+ drop_path_rate=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=0.1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.dropout = dropout
+ self.drop_path_rate = drop_path_rate
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.initializer_range = initializer_range
+ self.initializer_factor = initializer_factor
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.norm_type = norm_type
+ self.qkv_bias = qkv_bias
+ self.qk_normalization = qk_normalization
+ self.use_flash_attn = use_flash_attn
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if 'vision_config' in config_dict:
+ config_dict = config_dict['vision_config']
+
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
diff --git a/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl/model/internvl_chat/configuration_internvl_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb0e502c8ed44a2680693fc6e5ed1f53b01f3730
--- /dev/null
+++ b/internvl/model/internvl_chat/configuration_internvl_chat.py
@@ -0,0 +1,93 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import copy
+
+from transformers import AutoConfig
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+from .configuration_intern_vit import InternVisionConfig
+
+logger = logging.get_logger(__name__)
+
+
+class InternVLChatConfig(PretrainedConfig):
+ model_type = 'internvl_chat'
+ is_composition = True
+
+ def __init__(
+ self,
+ vision_config=None,
+ llm_config=None,
+ use_backbone_lora=0,
+ use_llm_lora=0,
+ pad2square=False,
+ select_layer=-1,
+ force_image_size=None,
+ hidden_size=2048,
+ downsample_ratio=0.5,
+ template=None,
+ dynamic_image_size=False,
+ use_thumbnail=False,
+ ps_version='v1',
+ min_dynamic_patch=1,
+ max_dynamic_patch=6,
+ **kwargs):
+ super().__init__(**kwargs)
+ # import pdb; pdb.set_trace()
+ if vision_config is None:
+ vision_config = {}
+ logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
+
+ self.vision_config = InternVisionConfig(**vision_config)
+ self.llm_config = None
+
+ self.use_backbone_lora = use_backbone_lora
+ self.use_llm_lora = use_llm_lora
+ self.pad2square = pad2square
+ self.select_layer = select_layer
+ self.force_image_size = force_image_size
+ self.downsample_ratio = downsample_ratio
+ self.template = template
+ self.dynamic_image_size = dynamic_image_size
+ self.use_thumbnail = use_thumbnail
+ self.ps_version = ps_version # pixel shuffle version
+ self.min_dynamic_patch = min_dynamic_patch
+ self.max_dynamic_patch = max_dynamic_patch
+
+ self.hidden_size = hidden_size
+ self.tie_word_embeddings = False
+
+ logger.info(f'vision_select_layer: {self.select_layer}')
+ logger.info(f'ps_version: {self.ps_version}')
+ logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
+ logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output['vision_config'] = self.vision_config.to_dict()
+ output['model_type'] = self.__class__.model_type
+ output['use_backbone_lora'] = self.use_backbone_lora
+ output['use_llm_lora'] = self.use_llm_lora
+ output['pad2square'] = self.pad2square
+ output['select_layer'] = self.select_layer
+ output['force_image_size'] = self.force_image_size
+ output['downsample_ratio'] = self.downsample_ratio
+ output['template'] = self.template
+ output['dynamic_image_size'] = self.dynamic_image_size
+ output['use_thumbnail'] = self.use_thumbnail
+ output['ps_version'] = self.ps_version
+ output['min_dynamic_patch'] = self.min_dynamic_patch
+ output['max_dynamic_patch'] = self.max_dynamic_patch
+
+ return output
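+
+
+if __name__ == '__main__':
+    # Illustrative sketch only: build a chat config from a small vision-config dict.
+    # The values below are placeholders chosen for demonstration, not trained defaults.
+    demo_config = InternVLChatConfig(vision_config={'image_size': 448, 'patch_size': 14},
+                                     force_image_size=448, downsample_ratio=0.5)
+    # The nested vision config is serialized back into the dictionary form.
+    print(demo_config.to_dict()['vision_config']['image_size'])  # 448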
diff --git a/internvl/model/internvl_chat/flash_attention.py b/internvl/model/internvl_chat/flash_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cda9bfadd290da35bdd04cccd51725e2d419c2f
--- /dev/null
+++ b/internvl/model/internvl_chat/flash_attention.py
@@ -0,0 +1,76 @@
+# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+try: # v1
+ from flash_attn.flash_attn_interface import \
+ flash_attn_unpadded_qkvpacked_func
+except ImportError:  # v2
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+
+from flash_attn.bert_padding import pad_input, unpad_input
+
+
+class FlashAttention(nn.Module):
+ """Implement the scaled dot product attention with softmax.
+ Arguments
+ ---------
+ softmax_scale: The temperature to use for the softmax attention.
+ (default: 1/sqrt(d_keys) where d_keys is computed at
+ runtime)
+ attention_dropout: The dropout rate to apply to the attention
+ (default: 0.0)
+ """
+
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+ super().__init__()
+ self.softmax_scale = softmax_scale
+ self.dropout_p = attention_dropout
+
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
+ max_s=None, need_weights=False):
+ """Implements the multihead softmax attention.
+ Arguments
+ ---------
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
+ if unpadded: (nnz, 3, h, d)
+ key_padding_mask: a bool tensor of shape (B, S)
+ """
+ assert not need_weights
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
+ assert qkv.is_cuda
+
+ if cu_seqlens is None:
+ batch_size = qkv.shape[0]
+ seqlen = qkv.shape[1]
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = seqlen
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+ device=qkv.device)
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
+ indices, batch_size, seqlen),
+ 'b s (h d) -> b s h d', h=nheads)
+ else:
+ assert max_s is not None
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
+ softmax_scale=self.softmax_scale, causal=causal
+ )
+
+ return output, None
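+
+
+if __name__ == '__main__':
+    # Shape sketch (illustrative only; requires a CUDA device and a working flash-attn install):
+    # qkv is packed as (batch, seq_len, 3, num_heads, head_dim) and must be fp16 or bf16.
+    if torch.cuda.is_available():
+        attn = FlashAttention(attention_dropout=0.0)
+        qkv = torch.randn(2, 16, 3, 8, 64, device='cuda', dtype=torch.float16)
+        out, _ = attn(qkv)
+        print(out.shape)  # torch.Size([2, 16, 8, 64])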
diff --git a/internvl/model/internvl_chat/modeling_intern_vit.py b/internvl/model/internvl_chat/modeling_intern_vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b98fe88c188f602b370b5f1a55f0bced38eb489
--- /dev/null
+++ b/internvl/model/internvl_chat/modeling_intern_vit.py
@@ -0,0 +1,364 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+import numpy as np
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from einops import rearrange
+from timm.models.layers import DropPath
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import (BaseModelOutput,
+ BaseModelOutputWithPooling)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_intern_vit import InternVisionConfig
+
+try:
+ from .flash_attention import FlashAttention
+ has_flash_attn = True
+except ImportError:
+ print('FlashAttention is not installed.')
+ has_flash_attn = False
+
+logger = logging.get_logger(__name__)
+
+
+class InternRMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+try:
+ from apex.normalization import FusedRMSNorm
+
+ InternRMSNorm = FusedRMSNorm # noqa
+
+ logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
+except ImportError:
+ # using the normal InternRMSNorm
+ pass
+except Exception:
+ logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
+ pass
+
+
+NORM2FN = {
+ 'rms_norm': InternRMSNorm,
+ 'layer_norm': nn.LayerNorm,
+}
+
+
+class InternVisionEmbeddings(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.class_embedding = nn.Parameter(
+ torch.randn(1, 1, self.embed_dim),
+ )
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches + 1
+
+ self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
+
+ def _get_pos_embed(self, pos_embed, H, W):
+ target_dtype = pos_embed.dtype
+ pos_embed = pos_embed.float().reshape(
+ 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
+ pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \
+ reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
+ return pos_embed
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(pixel_values)  # shape = [batch, embed_dim, grid_h, grid_w]
+ batch_size, _, height, width = patch_embeds.shape
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+ position_embedding = torch.cat([
+ self.position_embedding[:, :1, :],
+ self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
+ ], dim=1)
+ embeddings = embeddings + position_embedding.to(target_dtype)
+ return embeddings
+
+
+class InternAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
+ if config.use_flash_attn and not has_flash_attn:
+ print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
+ f' {self.num_heads}).'
+ )
+
+ self.scale = self.head_dim ** -0.5
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
+ self.attn_drop = nn.Dropout(config.attention_dropout)
+ self.proj_drop = nn.Dropout(config.dropout)
+
+ self.qk_normalization = config.qk_normalization
+
+ if self.qk_normalization:
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ if self.use_flash_attn:
+ self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def _naive_attn(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ if self.qk_normalization:
+ B_, H_, N_, D_ = q.shape
+ q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+ k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
+
+ attn = ((q * self.scale) @ k.transpose(-2, -1))
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
+ qkv = self.qkv(x)
+ qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+
+ if self.qk_normalization:
+ q, k, v = qkv.unbind(2)
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
+ qkv = torch.stack([q, k, v], dim=2)
+
+ context, _ = self.inner_attn(
+ qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
+ )
+ outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
+ outs = self.proj_drop(outs)
+ return outs
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
+ return x
+
+
+class InternMLP(nn.Module):
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ self.act = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class InternVisionEncoderLayer(nn.Module):
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.norm_type = config.norm_type
+
+ self.attn = InternAttention(config)
+ self.mlp = InternMLP(config)
+ self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+ self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
+
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
+ self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ """
+ hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
+
+ hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
+
+ return hidden_states
+
+
+class InternVisionEncoder(nn.Module):
+ """
+    Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is an
+    [`InternVisionEncoderLayer`].
+
+    Args:
+        config (`InternVisionConfig`):
+            The corresponding vision configuration for the `InternVisionEncoder`.
+ """
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__()
+ self.config = config
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
+ self.layers = nn.ModuleList([
+ InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = True
+
+ def forward(
+ self,
+ inputs_embeds,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Embedded representation of the inputs. Should be float, not int tokens.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ hidden_states = inputs_embeds
+
+ for idx, encoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ encoder_layer,
+ hidden_states)
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ )
+ hidden_states = layer_outputs
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states
+ )
+
+
+class InternVisionModel(PreTrainedModel):
+ main_input_name = 'pixel_values'
+ _supports_flash_attn_2 = True
+ config_class = InternVisionConfig
+ _no_split_modules = ['InternVisionEncoderLayer']
+
+ def __init__(self, config: InternVisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = InternVisionEmbeddings(config)
+ self.encoder = InternVisionEncoder(config)
+
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
+ pos_emb = self.embeddings.position_embedding
+ _, num_positions, embed_dim = pos_emb.shape
+ cls_emb = pos_emb[:, :1, :]
+ pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
+ pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
+ self.embeddings.image_size = new_size
+ logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
+
+ def get_input_embeddings(self):
+ return self.embeddings
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_embeds: Optional[torch.FloatTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if pixel_values is None and pixel_embeds is None:
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
+
+ if pixel_embeds is not None:
+ hidden_states = pixel_embeds
+ else:
+ if len(pixel_values.shape) == 4:
+ hidden_states = self.embeddings(pixel_values)
+ else:
+ raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ last_hidden_state = encoder_outputs.last_hidden_state
+ pooled_output = last_hidden_state[:, 0, :]
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
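+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative only): a deliberately tiny config rather than
+    # the large defaults, so the forward pass runs quickly on CPU without flash-attn.
+    tiny_config = InternVisionConfig(hidden_size=64, num_attention_heads=4, intermediate_size=128,
+                                     num_hidden_layers=2, image_size=28, patch_size=14,
+                                     use_flash_attn=False)
+    tiny_model = InternVisionModel(tiny_config).eval()
+    with torch.no_grad():
+        out = tiny_model(pixel_values=torch.randn(1, 3, 28, 28))
+    # 1 CLS token + (28 // 14) ** 2 patch tokens -> torch.Size([1, 5, 64])
+    print(out.last_hidden_state.shape)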
diff --git a/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl/model/internvl_chat/modeling_internvl_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..618f0ca0cfe00edac96198930dfc11cbf2cce2cd
--- /dev/null
+++ b/internvl/model/internvl_chat/modeling_internvl_chat.py
@@ -0,0 +1,424 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import numpy as np
+import warnings
+from typing import Any, List, Optional, Tuple, Union
+import torch.nn.functional as F
+import torch.distributed as dist
+import torch.utils.checkpoint
+import transformers
+from internvl.conversation import get_conv_template
+from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
+from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM
+from peft import LoraConfig, get_peft_model
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
+ LlamaTokenizer, Qwen2ForCausalLM)
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput, logging
+
+from .configuration_internvl_chat import InternVLChatConfig
+from .modeling_intern_vit import InternVisionModel
+
+logger = logging.get_logger(__name__)
+
+
+def version_cmp(v1, v2, op='eq'):
+ import operator
+
+ from packaging import version
+ op_func = getattr(operator, op)
+ return op_func(version.parse(v1), version.parse(v2))
+
+
+class InternVLChatModel(PreTrainedModel):
+ config_class = InternVLChatConfig
+ main_input_name = 'pixel_values'
+ base_model_prefix = ''
+ _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
+ 'Phi3DecoderLayer', 'Qwen2DecoderLayer']
+ _supports_flash_attn_2 = True
+ supports_gradient_checkpointing = True
+
+ def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
+ super().__init__(config)
+
+ assert version_cmp(transformers.__version__, '4.37.0', 'ge')
+ image_size = config.force_image_size or config.vision_config.image_size
+ patch_size = config.vision_config.patch_size
+ self.patch_size = patch_size
+ self.select_layer = config.select_layer
+ self.template = config.template
+ self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
+ self.downsample_ratio = config.downsample_ratio
+ self.ps_version = config.ps_version
+ # self.llm_arch_name = config.llm_config.architectures[0]
+
+ logger.info(f'num_image_token: {self.num_image_token}')
+ logger.info(f'ps_version: {self.ps_version}')
+ if vision_model is not None:
+ self.vision_model = vision_model
+ else:
+ self.vision_model = InternVisionModel(config.vision_config)
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, 2).to(torch.bfloat16)
+
+ # if language_model is not None:
+ # self.language_model = language_model
+ # else:
+ # if config.llm_config.architectures[0] == 'LlamaForCausalLM':
+ # self.language_model = LlamaForCausalLM(config.llm_config)
+ # elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
+ # self.language_model = InternLM2ForCausalLM(config.llm_config)
+ # elif config.llm_config.architectures[0] == 'Phi3ForCausalLM':
+ # self.language_model = Phi3ForCausalLM(config.llm_config)
+ # elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
+ # self.language_model = Qwen2ForCausalLM(config.llm_config)
+ # else:
+ # raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
+
+ vit_hidden_size = config.vision_config.hidden_size
+ llm_hidden_size = config.hidden_size
+
+ self.ocr_mlp = nn.Sequential(
+ nn.LayerNorm(vit_hidden_size),
+ nn.Linear(vit_hidden_size, llm_hidden_size),
+ nn.GELU(),
+ nn.Linear(llm_hidden_size, llm_hidden_size)
+ )
+ if config.train_stage > 1:
+ self.mlp1 = nn.Sequential(
+ nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+ nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
+ nn.GELU(),
+ nn.Linear(llm_hidden_size, llm_hidden_size)
+ )
+
+ self.img_context_token_id = None
+ self.conv_template = get_conv_template(self.template)
+ if hasattr(config, 'system_message'):
+ self.system_message = config.system_message
+ else:
+ self.system_message = self.conv_template.system_message
+ self.num_samples = 0
+
+ if config.use_backbone_lora:
+ self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
+
+ if config.use_llm_lora:
+ self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
+
+        init_tau = np.log(10)
+        init_b = -2.71
+ self.t_prime = nn.Parameter(torch.ones([]) * init_tau)
+ self.b = nn.Parameter(torch.ones([]) * init_b)
+ self.kb = False
+ self.upsample = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=1024,
+ out_channels=512,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias=False
+ ),
+ # nn.BatchNorm2d(512),
+ nn.SyncBatchNorm(512),
+            # second transposed conv: upsample further to the target resolution
+ nn.ConvTranspose2d(
+ in_channels=512,
+ out_channels=1024,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias=False
+ ),
+ # nn.BatchNorm2d(1024),
+ nn.SyncBatchNorm(1024),
+ )
+
+ def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+ lora_config = LoraConfig(
+ r=r,
+ target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ )
+ self.vision_model = get_peft_model(self.vision_model, lora_config)
+ self.vision_model.print_trainable_parameters()
+
+ def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
+ # Determine the target modules based on the architecture of the language model
+ if self.llm_arch_name == 'InternLM2ForCausalLM':
+ target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3']
+ elif self.llm_arch_name == 'Phi3ForCausalLM':
+ target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj']
+ elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
+ target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
+ 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj']
+ else:
+            raise NotImplementedError
+ lora_config = LoraConfig(
+ r=r,
+ target_modules=target_modules,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ task_type='CAUSAL_LM'
+ )
+ self.language_model = get_peft_model(self.language_model, lora_config)
+ self.language_model.enable_input_require_grads()
+ self.language_model.print_trainable_parameters()
+
+ def forward_tokenocr(self,
+                         pixel_values: torch.FloatTensor) -> Union[Tuple, CausalLMOutputWithPast]:
+ vit_embeds = self.extract_feature_custom(pixel_values) #(vit_batch_size, 16*16, 2048)
+ # vit_embeds = self.extract_feature_custom_no_upsample(pixel_values) #(vit_batch_size, 16*16, 2048)
+ return vit_embeds, None
+
+ def pixel_unshuffle(self, x, scale_factor=4):
+ h = w = int(x.shape[1] ** 0.5)
+ n, l, c = x.size()
+ x = x.reshape(n, h, w, c)
+ x = x.repeat_interleave(scale_factor, dim=1).repeat_interleave(scale_factor, dim=2)
+ return x
+
+ def pixel_shuffle(self, x, scale_factor=0.5):
+ n, w, h, c = x.size()
+ # N, W, H, C --> N, W, H * scale, C // scale
+ x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+ # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+ x = x.permute(0, 2, 1, 3).contiguous()
+ # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+ x = x.view(n, int(h * scale_factor), int(w * scale_factor),
+ int(c / (scale_factor * scale_factor)))
+ if self.ps_version == 'v1':
+ warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
+ 'which results in a transposed image.')
+ else:
+ x = x.permute(0, 2, 1, 3).contiguous()
+ return x
+
+ def extract_feature(self, pixel_values):
+ if self.select_layer == -1:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=False,
+ return_dict=True).last_hidden_state
+ else:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=True,
+ return_dict=True).hidden_states[self.select_layer]
+ vit_embeds = vit_embeds[:, 1:, :]
+
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+ vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ vit_embeds = self.mlp1(vit_embeds)
+ return vit_embeds
+
+ def extract_feature_custom(self, pixel_values):
+ if self.select_layer == -1:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=False,
+ return_dict=True).last_hidden_state
+ else:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=True,
+ return_dict=True).hidden_states[self.select_layer]
+ vit_embeds = vit_embeds[:, 1:, :] # (52, 1025, 1024)
+
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = vit_embeds.permute(0,2,1).reshape(vit_embeds.shape[0], -1, h, w)
+ vit_embeds = self.upsample(vit_embeds)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1])
+ vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1))
+ return vit_embeds
+
+ def extract_feature_custom_no_upsample(self, pixel_values):
+ if self.select_layer == -1:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=False,
+ return_dict=True).last_hidden_state
+ else:
+ vit_embeds = self.vision_model(
+ pixel_values=pixel_values,
+ output_hidden_states=True,
+ return_dict=True).hidden_states[self.select_layer]
+ vit_embeds = vit_embeds[:, 1:, :] # (52, 1025, 1024)
+
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ vit_embeds = self.ocr_mlp(vit_embeds)
+ # vit_embeds = self.pixel_unshuffle(vit_embeds)
+ # vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+ return vit_embeds
+
+ def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
+                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
+                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
+ if history is not None or return_history:
+ print('Now multi-turn chat is not supported in batch_chat.')
+ raise NotImplementedError
+
+ if image_counts is not None:
+ num_patches_list = image_counts
+ print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')
+
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ self.img_context_token_id = img_context_token_id
+
+ if verbose and pixel_values is not None:
+ image_bs = pixel_values.shape[0]
+ print(f'dynamic ViT batch size: {image_bs}')
+
+ queries = []
+ for idx, num_patches in enumerate(num_patches_list):
+ question = questions[idx]
+            if pixel_values is not None and '<image>' not in question:
+                question = '<image>\n' + question
+ template = get_conv_template(self.template)
+ template.system_message = self.system_message
+ template.append_message(template.roles[0], question)
+ template.append_message(template.roles[1], None)
+ query = template.get_prompt()
+
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+ queries.append(query)
+
+ tokenizer.padding_side = 'left'
+ model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
+ input_ids = model_inputs['input_ids'].cuda()
+ attention_mask = model_inputs['attention_mask'].cuda()
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
+ generation_config['eos_token_id'] = eos_token_id
+ generation_output = self.generate(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ **generation_config
+ )
+ responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
+ responses = [response.split(template.sep.strip())[0].strip() for response in responses]
+ return responses
+
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
+             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
+ verbose=False):
+
+        if history is None and pixel_values is not None and '<image>' not in question:
+            question = '<image>\n' + question
+
+ if num_patches_list is None:
+ num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
+ assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
+
+ img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ self.img_context_token_id = img_context_token_id
+
+ template = get_conv_template(self.template)
+ template.system_message = self.system_message
+ eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip())
+
+ history = [] if history is None else history
+ for (old_question, old_answer) in history:
+ template.append_message(template.roles[0], old_question)
+ template.append_message(template.roles[1], old_answer)
+ template.append_message(template.roles[0], question)
+ template.append_message(template.roles[1], None)
+ query = template.get_prompt()
+
+ if verbose and pixel_values is not None:
+ image_bs = pixel_values.shape[0]
+ print(f'dynamic ViT batch size: {image_bs}')
+
+ for num_patches in num_patches_list:
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
+            query = query.replace('<image>', image_tokens, 1)
+
+ model_inputs = tokenizer(query, return_tensors='pt')
+ input_ids = model_inputs['input_ids'].cuda()
+ attention_mask = model_inputs['attention_mask'].cuda()
+ generation_config['eos_token_id'] = eos_token_id
+ generation_output = self.generate(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ **generation_config
+ )
+ response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
+ response = response.split(template.sep.strip())[0].strip()
+ history.append((question, response))
+ if return_history:
+ return response, history
+ else:
+ query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
+ query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '')
+ if verbose:
+ print(query_to_print, response)
+ return response
+
+ @torch.no_grad()
+ def generate(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ input_ids: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ visual_features: Optional[torch.FloatTensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **generate_kwargs,
+ ) -> torch.LongTensor:
+
+ assert self.img_context_token_id is not None
+ if pixel_values is not None:
+ if visual_features is not None:
+ vit_embeds = visual_features
+ else:
+ vit_embeds = self.extract_feature(pixel_values)
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+ B, N, C = input_embeds.shape
+ input_embeds = input_embeds.reshape(B * N, C)
+
+ input_ids = input_ids.reshape(B * N)
+ selected = (input_ids == self.img_context_token_id)
+ assert selected.sum() != 0
+ input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
+
+ input_embeds = input_embeds.reshape(B, N, C)
+ else:
+ input_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+ outputs = self.language_model.generate(
+ inputs_embeds=input_embeds,
+ attention_mask=attention_mask,
+ generation_config=generation_config,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ use_cache=True,
+ **generate_kwargs,
+ )
+
+ return outputs
+
+ @property
+ def lm_head(self):
+ return self.language_model.get_output_embeddings()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
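+
+
+# Usage sketch (illustrative only; the checkpoint path is a placeholder and must point to a
+# directory compatible with this class):
+#
+#   model = InternVLChatModel.from_pretrained('path/to/checkpoint', torch_dtype=torch.bfloat16).eval()
+#   size = model.config.force_image_size
+#   pixel_values = torch.randn(1, 3, size, size, dtype=torch.bfloat16)
+#   vit_embeds, _ = model.forward_tokenocr(pixel_values)  # (batch, num_tokens, llm_hidden_size)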
diff --git a/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a08c18584f23108bd92a2b9f30ae5b11561fd1f
Binary files /dev/null and b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc differ
diff --git a/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..69240f7ea2679fc67a9ee3ae76d9ff72fcb62331
Binary files /dev/null and b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc differ
diff --git a/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..97c5684f72df959c344a8e775412519e295d882e
Binary files /dev/null and b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc differ
diff --git a/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30e2ed25a0480075cfdbba62c66dbce3419de256
Binary files /dev/null and b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc differ
diff --git a/internvl/model/phi3/configuration_phi3.py b/internvl/model/phi3/configuration_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..c657051097ebd7655786d74f8ed75635bfc844c4
--- /dev/null
+++ b/internvl/model/phi3/configuration_phi3.py
@@ -0,0 +1,211 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Phi-3 model configuration"""
+
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+ 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
+ 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
+}
+
+
+class Phi3Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+ defaults will yield a similar configuration to that of the
+ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32064):
+ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Phi3Model`].
+ hidden_size (`int`, *optional*, defaults to 3072):
+ Dimension of the hidden representations.
+ intermediate_size (`int`, *optional*, defaults to 8192):
+ Dimension of the MLP representations.
+ num_hidden_layers (`int`, *optional*, defaults to 32):
+ Number of hidden layers in the Transformer decoder.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+ `num_attention_heads`.
+ resid_pdrop (`float`, *optional*, defaults to 0.0):
+ Dropout probability for mlp outputs.
+ embd_pdrop (`int`, *optional*, defaults to 0.0):
+ The dropout ratio for the embeddings.
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio after computing the attention scores.
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+ The non-linear activation function (function or string) in the decoder.
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ original_max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model was trained with. This is used to determine the size of the
+ original RoPE embeddings when using long scaling.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon value used for the RMSNorm.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the input and output word embeddings.
+ rope_theta (`float`, *optional*, defaults to 10000.0):
+ The base period of the RoPE embeddings.
+ rope_scaling (`dict`, *optional*):
+ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
+ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
+ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
+ divided by the number of attention heads divided by 2.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 32000):
+ The id of the "end-of-sequence" token.
+ pad_token_id (`int`, *optional*, defaults to 32000):
+ The id of the padding token.
+ sliding_window (`int`, *optional*):
+ Sliding window attention window size. If `None`, no sliding window is applied.
+
+ Example:
+
+ ```python
+ >>> from transformers import Phi3Model, Phi3Config
+
+ >>> # Initializing a Phi-3 style configuration
+ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+
+ >>> # Initializing a model from the configuration
+ >>> model = Phi3Model(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = 'phi3'
+ keys_to_ignore_at_inference = ['past_key_values']
+
+ def __init__(
+ self,
+ vocab_size=32064,
+ hidden_size=3072,
+ intermediate_size=8192,
+ num_hidden_layers=32,
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ resid_pdrop=0.0,
+ embd_pdrop=0.0,
+ attention_dropout=0.0,
+ hidden_act='silu',
+ max_position_embeddings=4096,
+ original_max_position_embeddings=4096,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ tie_word_embeddings=False,
+ rope_theta=10000.0,
+ rope_scaling=None,
+ bos_token_id=1,
+ eos_token_id=32000,
+ pad_token_id=32000,
+ sliding_window=None,
+ **kwargs,
+ ):
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+
+ self.num_key_value_heads = num_key_value_heads
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attention_dropout = attention_dropout
+ self.hidden_act = hidden_act
+ self.max_position_embeddings = max_position_embeddings
+ self.original_max_position_embeddings = original_max_position_embeddings
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.rope_theta = rope_theta
+ self.rope_scaling = rope_scaling
+ self._rope_scaling_validation()
+ self.sliding_window = sliding_window
+
+ super().__init__(
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ tie_word_embeddings=tie_word_embeddings,
+ **kwargs,
+ )
+
+ def _rope_scaling_validation(self):
+ """
+ Validate the `rope_scaling` configuration.
+ """
+ if self.rope_scaling is None:
+ return
+
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
+ raise ValueError(
+ '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
+ f'got {self.rope_scaling}'
+ )
+ rope_scaling_type = self.rope_scaling.get('type', None)
+ rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
+ rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
+ if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
+ raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
+ if not (
+ isinstance(rope_scaling_short_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
+ )
+ if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
+ raise ValueError(
+ f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
+ )
+ if not (
+ isinstance(rope_scaling_long_factor, list)
+ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
+ ):
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
+ )
+ if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
+ raise ValueError(
+ f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
+ )
diff --git a/internvl/model/phi3/modeling_phi3.py b/internvl/model/phi3/modeling_phi3.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bb5dc0450c0435262f5d32afa0533ad4161d92d
--- /dev/null
+++ b/internvl/model/phi3/modeling_phi3.py
@@ -0,0 +1,1610 @@
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" PyTorch Phi-3 model."""
+
+import inspect
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import \
+ _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_2_available,
+ is_flash_attn_greater_or_equal_2_10, logging,
+ replace_return_docstrings)
+
+from .configuration_phi3 import Phi3Config
+
+logger = logging.get_logger(__name__)
+
+# Transformers scans dependencies in the modeling file, which causes issues with conditional loading:
+# its regex only ignores try/except blocks, not if statements.
+# if is_flash_attn_2_available():
+_flash_supports_window_size = False
+try:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
+ unpad_input)
+
+ _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
+ has_flash_attn = True
+except ImportError as error:
+ logger.warning(
+ f'`flash-attention` package not found, consider installing for better performance: {error}.'
+ )
+    if not _flash_supports_window_size:
+        logger.warning(
+            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
+        )
+ has_flash_attn = False
+
+_CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct'
+_CONFIG_FOR_DOC = 'Phi3Config'
+
+PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ 'microsoft/Phi-3-mini-4k-instruct',
+ 'microsoft/Phi-3-mini-128k-instruct',
+ # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
+]
+
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
+class Phi3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Phi3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
+class Phi3RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.register_buffer('inv_freq', None, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if self.inv_freq is None:
+ self.inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+ )
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
+ def __init__(self, dim, config, device=None):
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+ self.short_factor = config.rope_scaling['short_factor']
+ self.long_factor = config.rope_scaling['long_factor']
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
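+ # "su" scaling: rescale cos/sin by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)) when extending past the original context.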
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
+ def __init__(self, dim, config, device=None):
+ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
+
+ self.short_factor = config.rope_scaling['short_factor']
+ self.long_factor = config.rope_scaling['long_factor']
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
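+ # YaRN attention temperature: rescale cos/sin by 0.1 * ln(scale) + 1.0 when extending past the original context.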
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = 0.1 * math.log(scale) + 1.0
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+# Copied from transformers.models.llama.modeling_llama.rotate_half
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class Phi3MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+ up_states = self.gate_up_proj(hidden_states)
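+ # gate_up_proj is a fused projection: the first half is the gate branch, the second half the up branch (gated activation).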
+
+ gate, up_states = up_states.chunk(2, dim=-1)
+ up_states = up_states * self.activation_fn(gate)
+
+ return self.down_proj(up_states)
+
+
+# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class Phi3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ if layer_idx is None:
+ logger.warning_once(
+ f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
+ 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
+ 'when creating this class.'
+ )
+
+ self.attention_dropout = config.attention_dropout
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.rope_scaling = config.rope_scaling
+ self.is_causal = True
+
+ if (self.head_dim * self.num_heads) != self.hidden_size:
+ raise ValueError(
+ f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
+ f' and `num_heads`: {self.num_heads}).'
+ )
+
+ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
+ self._init_rope()
+
+ def _init_rope(self):
+ if self.rope_scaling is None:
+ self.rotary_emb = Phi3RotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ else:
+ scaling_type = self.config.rope_scaling['type']
+ if scaling_type == 'su':
+ self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
+ elif scaling_type == 'yarn':
+ self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
+ else:
+ raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.')
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
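+ # qkv_proj is a fused projection: [queries | keys | values]; it is split below according to head counts.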
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
+ 'with a layer index.'
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
+ f' {attn_weights.size()}'
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}'
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class Phi3FlashAttention2(Phi3Attention):
+ """
+ Phi-3 flash attention module. This module inherits from `Phi3Attention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and handle padding tokens if the input contains any.
+ """
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Phi3FlashAttention2 attention does not support output_attentions
+
+ if not _flash_supports_window_size:
+ logger.warning_once(
+ "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
+ )
+ raise ValueError('The current flash attention version does not support sliding window attention.')
+
+ output_attentions = False
+
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
+ )
+
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim.
+ # We first move to (bsz, num_heads, q_len, head_dim) for RoPE and cache handling, then transpose back further below.
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
+ 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
+ 'with a layer index.'
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ use_sliding_windows = (
+ _flash_supports_window_size
+ and getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ )
+
+ if past_key_value is not None:
+ # Activate cache slicing only if the config defines a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, 'sliding_window', None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
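+ # The cache has outgrown the sliding window: keep only the most recent (sliding_window - 1) key/value tokens.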
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
+ f' {past_key.shape}'
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_dropout = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states get silently cast to float32 as well. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.qkv_proj.weight.dtype
+
+ logger.warning_once(
+ f'The input hidden states seem to have been silently cast to float32; this is likely because'
+ f' you have upcasted the embedding or layer norm layers to float32. We will cast the input back to'
+ f' {target_dtype}.'
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = self._flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=attn_dropout,
+ use_sliding_windows=use_sliding_windows,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+ the input is first unpadded, then the attention scores are computed, and the final output is re-padded.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ if not use_sliding_windows:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+ else:
+ if not use_sliding_windows:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_func(
+ query_states,
+ key_states,
+ value_states,
+ dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ return attn_output
+
+ # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+ batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
+
+ # On the first iteration we need to properly re-create the padding mask
+ # by slicing it at the right place
+ if kv_seq_len != attention_mask.shape[-1]:
+ attention_mask_num_tokens = attention_mask.shape[-1]
+ attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
+
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+ value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+ )
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
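+ # Decode step: each sequence contributes exactly one query token, so cu_seqlens_q is simply arange(batch_size + 1).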
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
+# TODO @Arthur no longer copied from LLama after static cache
+class Phi3SdpaAttention(Phi3Attention):
+ """
+ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Phi3Attention`, as the weights of the module stay untouched. The only changes are in the forward pass to adapt to
+ the SDPA API.
+ """
+
+ # Adapted from Phi3Attention.forward
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == 'cuda' and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+PHI3_ATTENTION_CLASSES = {
+ 'eager': Phi3Attention,
+ 'flash_attention_2': Phi3FlashAttention2,
+ 'sdpa': Phi3SdpaAttention,
+}
+
+
+class Phi3DecoderLayer(nn.Module):
+ def __init__(self, config: Phi3Config, layer_idx: int):
+ super().__init__()
+
+ self.config = config
+ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
+
+ self.mlp = Phi3MLP(config)
+ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
+ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
+ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ if 'padding_mask' in kwargs:
+ warnings.warn(
+ 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
+ )
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range
+ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = residual + self.resid_attn_dropout(attn_outputs)
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + self.resid_mlp_dropout(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+PHI3_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Phi3Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
+ PHI3_START_DOCSTRING,
+)
+class Phi3PreTrainedModel(PreTrainedModel):
+ config_class = Phi3Config
+ base_model_prefix = 'model'
+ supports_gradient_checkpointing = True
+ _no_split_modules = ['Phi3DecoderLayer']
+ _skip_keys_device_placement = 'past_key_values'
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True
+
+ _version = '0.0.5'
+
+ def __init__(self, config: Phi3Config):
+ if not has_flash_attn:
+ config._attn_implementation = 'eager'
+ print('Warning: Flash attention is not available, using eager attention instead.')
+ super().__init__(config)
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+PHI3_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
+
+ Two formats are allowed:
+ - a [`~cache_utils.Cache`] instance;
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
+ cache format.
+
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
+ legacy cache format will be returned.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
+ of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
+ PHI3_START_DOCSTRING,
+)
+class Phi3Model(Phi3PreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
+
+ Args:
+ config: Phi3Config
+ """
+
+ def __init__(self, config: Phi3Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
+ self.layers = nn.ModuleList(
+ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self._attn_implementation = config._attn_implementation
+
+ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError('You have to specify either input_ids or inputs_embeds')
+
+ past_key_values_length = 0
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
+ )
+ use_cache = False
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
+ is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+ if is_padding_right:
+ raise ValueError(
+ "You are attempting to perform batched generation with padding_side='right'"
+ ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to '
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+ )
+
+ if self._attn_implementation == 'flash_attention_2':
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
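+ # The decoder layer returns (hidden_states, [attn_weights], [present_key_value]); the cache is always the last element.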
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class Phi3ForCausalLM(Phi3PreTrainedModel):
+ _tied_weights_keys = ['lm_head.weight']
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Phi3Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
+ def get_decoder(self):
+ return self.model
+
+ # Ignore copy
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+ >>> prompt = "This is an example script ."
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
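+ # Upcast logits to float32 so that the loss below is computed in full precision.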
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = CrossEntropyLoss()
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
+ shift_labels = shift_labels.view(-1)
+ # Enable model parallelism
+ shift_labels = shift_labels.to(shift_logits.device)
+ loss = loss_fct(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ past_length = past_key_values.seen_tokens
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get('position_ids', None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if (inputs_embeds is not None and past_key_values is None) or (inputs_embeds is not None and len(past_key_values) == 0):
+ model_inputs = {'inputs_embeds': inputs_embeds}
+ else:
+ model_inputs = {'input_ids': input_ids}
+
+ model_inputs.update(
+ {
+ 'position_ids': position_ids,
+ 'past_key_values': past_key_values,
+ 'use_cache': kwargs.get('use_cache'),
+ 'attention_mask': attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+
+@add_start_docstrings(
+ """
+ The [`Phi3Model`] with a sequence classification head on top (linear layer).
+
+ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it needs to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
+class Phi3ForSequenceClassification(Phi3PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = Phi3Model(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ model_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = model_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
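+ # Pool the logit of the last non-padding token in each sequence (or simply the last position when sequence_lengths is -1).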
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = 'regression'
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = 'single_label_classification'
+ else:
+ self.config.problem_type = 'multi_label_classification'
+
+ if self.config.problem_type == 'regression':
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == 'single_label_classification':
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == 'multi_label_classification':
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + model_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=model_outputs.past_key_values,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ PHI3_START_DOCSTRING,
+)
+# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
+class Phi3ForTokenClassification(Phi3PreTrainedModel):
+ def __init__(self, config: Phi3Config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.model = Phi3Model(config)
+ if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
+ classifier_dropout = config.hidden_dropout
+ else:
+ classifier_dropout = 0.1
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ model_outputs = self.model(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = model_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ batch_size, seq_length = labels.shape
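+ # Flatten logits to (batch_size * seq_length, num_labels) and labels to
+ # (batch_size * seq_length,) so cross-entropy scores every token position independently.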
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
+ )
+
+ if not return_dict:
+ output = (logits,) + model_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=model_outputs.hidden_states,
+ attentions=model_outputs.attentions,
+ )
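+
+
+# Usage sketch (illustrative only; the fine-tuned checkpoint path below is hypothetical):
+#   tokenizer = AutoTokenizer.from_pretrained('microsoft/Phi-3-mini-4k-instruct')
+#   model = Phi3ForTokenClassification.from_pretrained('path/to/ner-finetuned-phi3')
+#   inputs = tokenizer('InternVL is developed by OpenGVLab', return_tensors='pt')
+#   logits = model(**inputs).logits          # (1, seq_len, num_labels)
+#   predictions = logits.argmax(dim=-1)      # per-token label ids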
diff --git a/internvl/patch/__init__.py b/internvl/patch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f84fe933ba5e62ac4329f1b2db164d22b9df12
--- /dev/null
+++ b/internvl/patch/__init__.py
@@ -0,0 +1,34 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+from .internlm2_packed_training_patch import replace_internlm2_attention_class
+from .internvit_liger_monkey_patch import apply_liger_kernel_to_internvit
+from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn
+from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+from .llama_packed_training_patch import replace_llama_attention_class
+from .llama_rmsnorm_monkey_patch import \
+ replace_llama_rmsnorm_with_fused_rmsnorm
+from .pad_data_collator import (concat_pad_data_collator,
+ dpo_concat_pad_data_collator,
+ pad_data_collator)
+from .phi3_packed_training_patch import replace_phi3_attention_class
+from .qwen2_packed_training_patch import replace_qwen2_attention_class
+from .train_dataloader_patch import replace_train_dataloader
+from .train_sampler_patch import replace_train_sampler
+
+__all__ = ['replace_llama_attn_with_flash_attn',
+ 'replace_llama_rmsnorm_with_fused_rmsnorm',
+ 'replace_llama2_attn_with_flash_attn',
+ 'replace_train_sampler',
+ 'replace_train_dataloader',
+ 'replace_internlm2_attention_class',
+ 'replace_qwen2_attention_class',
+ 'replace_phi3_attention_class',
+ 'replace_llama_attention_class',
+ 'pad_data_collator',
+ 'dpo_concat_pad_data_collator',
+ 'concat_pad_data_collator',
+ 'apply_liger_kernel_to_internvit']
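+
+# Typical usage (illustrative): a training entry point applies the patches it needs
+# before the model and Trainer are built, e.g.
+#   from internvl.patch import replace_train_sampler, replace_train_dataloader
+#   replace_train_sampler()     # length-grouped sampler for concatenated datasets
+#   replace_train_dataloader()  # dataloader that leaves packed datasets unwrapped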
diff --git a/internvl/patch/__pycache__/__init__.cpython-39.pyc b/internvl/patch/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..719c9fd406cc93c434d720d4f1457c0a89f50815
Binary files /dev/null and b/internvl/patch/__pycache__/__init__.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce626181c4880ac68e765c49c27d58cb7841c1f6
Binary files /dev/null and b/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c66f6ccde4b246a8b791b1266b41c866860807f1
Binary files /dev/null and b/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e6d17d0ca475b1474d0e7b5f7d59b5003d8b407
Binary files /dev/null and b/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d27cda75494b01780dc1caf565fb902500953806
Binary files /dev/null and b/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f3e97973d10abf40cbf8fff67ccdd7a5e029253
Binary files /dev/null and b/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18aef4261936aa693ad1c3ceba294a36d13ef51e
Binary files /dev/null and b/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc b/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9ef9db1f0117bd90a1f145a5da27278b52ec236
Binary files /dev/null and b/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d88d3bbbbdb30bbec35cf44fd5f9418b6dd0c37c
Binary files /dev/null and b/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec200a6c9a2dba76069834e18d91195b20fa83cd
Binary files /dev/null and b/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc b/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ffeb0a267ace08d24f1a6ca8fa9afcb3b26be50b
Binary files /dev/null and b/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc differ
diff --git a/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc b/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3f500d392e711a0c32b90d128d969b92eb801903
Binary files /dev/null and b/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc differ
diff --git a/internvl/patch/internlm2_packed_training_patch.py b/internvl/patch/internlm2_packed_training_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0805e63876f23645c1802b163bbdb844de5fe27
--- /dev/null
+++ b/internvl/patch/internlm2_packed_training_patch.py
@@ -0,0 +1,74 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from internvl.model.internlm2.modeling_internlm2 import (
+ INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2,
+ apply_rotary_pos_emb)
+
+
+# Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2
+class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2):
+
+ def _flash_attention_forward(
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+ ):
+ """
+ Calls the Flash Attention varlen kernel on a packed batch: several sequences are concatenated along the
+ length dimension of a single batch element, and `attention_mask` carries their cumulative sequence lengths.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ Renamed from `cu_seqlens` to keep the signature compatible - shape `(batch_size + 1,)`, dtype
+ `torch.int32`, the cumulative sequence lengths of the sequences in the batch.
+ dropout (`float`, *optional*):
+ Attention dropout probability
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ """
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
+ query_states = query_states.squeeze(0)
+ key_states = key_states.squeeze(0)
+ value_states = value_states.squeeze(0)
+ cu_seqlens = attention_mask.squeeze(0)
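+ # Packed layout: the batch dimension is 1 and several sequences are concatenated along
+ # the length dimension; `cu_seqlens` marks their boundaries. For example, packing
+ # sequences of lengths 3, 5 and 2 gives cu_seqlens = [0, 3, 8, 10], so the
+ # max_seqlen computed below is 5.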
+
+ with torch.no_grad():
+ max_seqlen = max([
+ cu_seqlens[idx+1] - cu_seqlens[idx]
+ for idx in range(cu_seqlens.size(0) - 1)
+ ]).item()
+
+ # Contains at least one padding token in the sequence
+ causal = self.is_causal and query_length != 1
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+
+ query_states = query_states.unsqueeze(0)
+ key_states = key_states.unsqueeze(0)
+ value_states = value_states.unsqueeze(0)
+ return attn_output
+
+
+def replace_internlm2_attention_class():
+ INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining
+ print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!')
diff --git a/internvl/patch/internvit_liger_monkey_patch.py b/internvl/patch/internvit_liger_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..154aab1d1884968530329e08123276ba71faa07f
--- /dev/null
+++ b/internvl/patch/internvit_liger_monkey_patch.py
@@ -0,0 +1,13 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+def apply_liger_kernel_to_internvit() -> None:
+ from internvl.model.internvl_chat import modeling_intern_vit
+ from liger_kernel.transformers.layer_norm import LigerLayerNorm
+ from liger_kernel.transformers.rms_norm import LigerRMSNorm
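+ # Swap the norm constructors registered for InternViT with Liger's fused kernels;
+ # any InternViT instantiated afterwards picks them up through the NORM2FN registry.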
+ modeling_intern_vit.NORM2FN['rms_norm'] = LigerRMSNorm
+ modeling_intern_vit.NORM2FN['layer_norm'] = LigerLayerNorm
+ print('Liger kernel applied to InternViT')
diff --git a/internvl/patch/llama2_flash_attn_monkey_patch.py b/internvl/patch/llama2_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..01091ef2a0c1197b4638a4b8d95214670ac94039
--- /dev/null
+++ b/internvl/patch/llama2_flash_attn_monkey_patch.py
@@ -0,0 +1,238 @@
+"""
+This file is copied from: https://github.com/lm-sys/FastChat
+"""
+
+import warnings
+from typing import Optional, Tuple
+
+import torch
+from flash_attn import __version__ as flash_attn_version
+from flash_attn.bert_padding import pad_input, unpad_input
+from flash_attn.flash_attn_interface import (flash_attn_func,
+ flash_attn_varlen_kvpacked_func)
+from transformers.models.llama.modeling_llama import (LlamaAttention,
+ LlamaModel, rotate_half)
+
+
+def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
+ gather_indices = gather_indices.repeat(
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
+ )
+ bsz = gather_indices.shape[0]
+ cos, sin = (
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
+ for x in cos_sin
+ )
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
+ return q, k
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.'
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+ kv_heads = getattr(self, 'num_key_value_heads', self.num_heads)
+
+ q, k, v = (
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
+ for op, nh in (
+ (self.q_proj, self.num_heads),
+ (self.k_proj, kv_heads),
+ (self.v_proj, kv_heads),
+ )
+ )
+ # shape: (b, s, num_heads, head_dim)
+
+ kv_seq_len = k.shape[1]
+ past_kv_len = 0
+ if past_key_value is not None:
+ past_kv_len = past_key_value[0].shape[2]
+ kv_seq_len += past_kv_len
+
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
+ q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
+
+ if past_key_value is not None:
+ assert (
+ flash_attn_version >= '2.1.0'
+ ), 'past_key_value support requires flash-attn >= 2.1.0'
+ # reuse k, v
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
+
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
+
+ if attention_mask is None:
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
+ bsz, q_len, -1
+ )
+ else:
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
+ kv, _, cu_k_lens, max_k = unpad_input(
+ torch.stack((k, v), dim=2), attention_mask
+ )
+ output_unpad = flash_attn_varlen_kvpacked_func(
+ q,
+ kv,
+ cu_q_lens,
+ cu_k_lens,
+ max_s,
+ max_k,
+ 0.0,
+ softmax_scale=None,
+ causal=True,
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as flash attention
+# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ if past_key_values_length > 0 and attention_mask is not None:
+ attention_mask = torch.cat(
+ (
+ torch.full(
+ (input_shape[0], past_key_values_length),
+ True,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ ),
+ attention_mask,
+ ),
+ dim=-1,
+ )
+
+ if attention_mask is not None and torch.all(attention_mask):
+ return None # This uses the faster call when training with full samples
+
+ return attention_mask
+
+
+def replace_llama2_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.'
+ 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593'
+ )
+
+ LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ LlamaAttention.forward = forward
+
+
+def test():
+ from fastchat.train.llama_flash_attn_monkey_patch import \
+ forward as fastchat_forward
+ from transformers.models.llama.configuration_llama import LlamaConfig
+
+ config = LlamaConfig(
+ hidden_size=1024,
+ intermediate_size=128,
+ num_hidden_layers=1,
+ num_attention_heads=8,
+ max_position_embeddings=16,
+ )
+ device = torch.device('cuda')
+ model = LlamaModel(config)
+ attn = LlamaAttention(config).to(device).half()
+ bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
+ -1, seqlen
+ )
+
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
+ for i in range(4):
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
+ if i:
+ mask[0, -i:] = False
+ mask[1, :i] = False
+
+ lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
+ ref, _, _ = attn.forward(
+ hidden, attention_mask=lmask, position_ids=position_ids
+ )
+
+ fast, _, _ = fastchat_forward(
+ attn, hidden, attention_mask=mask, position_ids=position_ids
+ )
+
+ lmask = _prepare_decoder_attention_mask(
+ model, mask, hidden.shape[:2], hidden, 0
+ )
+ test, _, _ = forward(
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
+ )
+
+ print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}')
+ print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}')
+ print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}')
+ print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}')
+ print(f'allclose(fast, test) = {torch.allclose(fast, test)}')
+
+ with torch.no_grad():
+ # Also check that past_kv is handled properly
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
+ part_len = seqlen // 4
+ assert part_len * 4 == seqlen
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
+ mask[0, -2:] = False
+ lmask = _prepare_decoder_attention_mask(
+ model, mask, hidden.shape[:2], hidden, 0
+ )
+ oneshot, _, _ = forward(
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
+ )
+ parts = []
+ past_kv, past_kv_len = None, 0
+ for i in range(4):
+ start = part_len * i
+ end = start + part_len
+ hidden_part = hidden[:, start:end, ...]
+ lmask = _prepare_decoder_attention_mask(
+ model,
+ mask[:, start:end],
+ hidden_part.shape[:2],
+ hidden_part,
+ past_kv_len,
+ )
+ part, _, past_kv = forward(
+ attn,
+ hidden_part.clone(),
+ attention_mask=lmask,
+ position_ids=position_ids[:, start:end],
+ past_key_value=past_kv,
+ use_cache=True,
+ )
+ parts.append(part)
+ past_kv_len = past_kv[0].shape[2]
+
+ print(
+ f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}'
+ )
+ print(
+ f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}'
+ )
+
+
+if __name__ == '__main__':
+ test()
diff --git a/internvl/patch/llama_flash_attn_monkey_patch.py b/internvl/patch/llama_flash_attn_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b01d5811ad7860cd00786d83ecc5aafaf82aa4
--- /dev/null
+++ b/internvl/patch/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,222 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import math
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+import transformers
+from torch import nn
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel
+
+ attention_mask: [bsz, q_len]
+ """
+ from einops import rearrange
+ try: # v1
+ from flash_attn.flash_attn_interface import \
+ flash_attn_unpadded_qkvpacked_func
+ except ImportError: # v2
+ from flash_attn.flash_attn_interface import \
+ flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+ from flash_attn.bert_padding import pad_input, unpad_input
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ # [bsz, q_len, nh, hd]
+ # [bsz, nh, q_len, hd]
+
+ kv_seq_len = key_states.shape[-2]
+ assert past_key_value is None, 'past_key_value is not supported'
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+ assert not output_attentions, 'output_attentions is not supported'
+ assert not use_cache, 'use_cache is not supported'
+
+ # Flash attention codes from
+ # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
+
+ # transform the data into the format required by flash attention
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ # We have disabled _prepare_decoder_attention_mask in LlamaModel
+ # the attention_mask should be the same as the key_padding_mask
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ max_s = q_len
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
+ else:
+ nheads = qkv.shape[-2]
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
+ x_unpad = rearrange(
+ x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads
+ )
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(
+ pad_input(
+ rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len
+ ),
+ 'b s (h d) -> b s h d',
+ h=nheads,
+ )
+ return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def forward_2(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ assert not output_attentions, 'output_attentions is not supported'
+ assert not use_cache, 'use_cache is not supported'
+ assert past_key_value is None, 'past_key_value is not supported'
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+ if self.training:
+ attn_output = F.scaled_dot_product_attention(
+ query_states, key_states, value_states, dropout_p=0.0, is_causal=True
+ )
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
+ f' {attn_weights.size()}'
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+ f' {attn_output.size()}'
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def replace_llama_attn_with_flash_attn():
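+ # On PyTorch >= 2.0, F.scaled_dot_product_attention is available and the SDPA-based
+ # forward_2 is installed; otherwise fall back to the flash-attn path in `forward`,
+ # which also requires disabling the default attention-mask preparation.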
+ if hasattr(F, 'scaled_dot_product_attention'):
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
+ else:
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/internvl/patch/llama_packed_training_patch.py b/internvl/patch/llama_packed_training_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c915554919ba72f53544744c53f59d9e131c41c
--- /dev/null
+++ b/internvl/patch/llama_packed_training_patch.py
@@ -0,0 +1,106 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES,
+ LlamaFlashAttention2)
+
+
+# Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2
+class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2):
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the Flash Attention varlen kernel on a packed batch: several sequences are concatenated along the
+ length dimension of a single batch element, and `attention_mask` carries their cumulative sequence lengths.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ Repurposed to carry the cumulative sequence lengths (`cu_seqlens`) of the packed sequences -
+ shape `(batch_size + 1,)`, dtype `torch.int32`.
+ dropout (`float`, *optional*):
+ Attention dropout probability
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
+ query_states = query_states.squeeze(0)
+ key_states = key_states.squeeze(0)
+ value_states = value_states.squeeze(0)
+ cu_seqlens = attention_mask.squeeze(0)
+
+ with torch.no_grad():
+ max_seqlen = max([
+ cu_seqlens[idx+1] - cu_seqlens[idx]
+ for idx in range(cu_seqlens.size(0) - 1)
+ ]).item()
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Decide whether to use SWA or not by layer index.
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
+ use_sliding_windows = False
+
+ if not use_sliding_windows:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ query_states = query_states.unsqueeze(0)
+ key_states = key_states.unsqueeze(0)
+ value_states = value_states.unsqueeze(0)
+ return attn_output
+
+
+def replace_llama_attention_class():
+ LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining
+ print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!')
diff --git a/internvl/patch/llama_rmsnorm_monkey_patch.py b/internvl/patch/llama_rmsnorm_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..68efb44468bef51066d6633732b4f15de6daa04c
--- /dev/null
+++ b/internvl/patch/llama_rmsnorm_monkey_patch.py
@@ -0,0 +1,23 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import transformers
+
+
+def replace_llama_rmsnorm_with_fused_rmsnorm():
+ try:
+ from functools import partial
+
+ from apex.normalization import FusedRMSNorm
+ LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+ print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
+ except ImportError:
+ # using the normal LlamaRMSNorm
+ pass
+ except Exception:
+ print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
+ pass
diff --git a/internvl/patch/pad_data_collator.py b/internvl/patch/pad_data_collator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f803ed5f76b5a2f621dcc5fc1242438872558c5c
--- /dev/null
+++ b/internvl/patch/pad_data_collator.py
@@ -0,0 +1,206 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+IGNORE_INDEX = -100
+
+
+def pad_data_collator(features, pad_id=0):
+
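+ # Right-pad every sample's `input_ids` (with `pad_id`) and `labels` (with IGNORE_INDEX)
+ # to the longest sequence in the batch, e.g. lengths [4, 7, 5] all become 7, and derive
+ # `attention_mask` from the non-padded positions.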
+ first = features[0]
+ batch = {}
+
+ batch_lens = [feat['input_ids'].shape for feat in features]
+ max_item_length = max(batch_lens)[0]
+ for idx in range(len(features)):
+ feat = features[idx]
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+ temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
+ feat['input_ids'] = temp_input_ids
+ temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+ temp_labels[:feat['labels'].shape[0]] = feat['labels']
+ feat['labels'] = temp_labels
+ feat['attention_mask'] = feat['input_ids'].ne(pad_id)
+
+ # Special handling for labels.
+ # Ensure that tensor is created with the correct type
+ # (it should be automatically the case, but let's make sure of it.)
+ if 'label' in first and first['label'] is not None:
+ label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
+ dtype = torch.long if isinstance(label, int) else torch.float
+ batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
+ elif 'label_ids' in first and first['label_ids'] is not None:
+ if isinstance(first['label_ids'], torch.Tensor):
+ batch['labels'] = torch.stack([f['label_ids'] for f in features])
+ else:
+ dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
+ batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
+
+ # Handling of all other possible keys.
+ # Again, we will use the first element to figure out which key/values are not None for this model.
+ for k, v in first.items():
+ if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.stack([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = torch.tensor([f[k] for f in features])
+ return batch
+
+
+# def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
+
+# first = features[0]
+# batch = {}
+
+# batch_lens = [feat['input_ids'].shape for feat in features]
+# max_item_length = max_item_length or max(batch_lens)[0]
+# for idx in range(len(features)):
+# feat = features[idx]
+# temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+# temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
+# feat['input_ids'] = temp_input_ids
+# temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+# temp_labels[:feat['labels'].shape[0]] = feat['labels']
+# feat['labels'] = temp_labels
+# feat['attention_mask'] = feat['input_ids'].ne(pad_id)
+
+# if 'position_ids' in feat:
+# temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
+# temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
+# feat['position_ids'] = temp_position_ids
+
+# if 'loss_weight' in feat:
+# temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length)
+# temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight']
+# feat['loss_weight'] = temp_loss_weight
+
+# # Special handling for labels.
+# # Ensure that tensor is created with the correct type
+# # (it should be automatically the case, but let's make sure of it.)
+# if 'label' in first and first['label'] is not None:
+# label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
+# dtype = torch.long if isinstance(label, int) else torch.float
+# batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
+# elif 'label_ids' in first and first['label_ids'] is not None:
+# if isinstance(first['label_ids'], torch.Tensor):
+# batch['labels'] = torch.stack([f['label_ids'] for f in features])
+# else:
+# dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
+# batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
+
+# # Handling of all other possible keys.
+# # Again, we will use the first element to figure out which key/values are not None for this model.
+# for k, v in first.items():
+# if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
+# v is not None and not isinstance(v, str):
+# if isinstance(v, torch.Tensor):
+# batch[k] = torch.stack([f[k] for f in features])
+# elif isinstance(v, np.ndarray):
+# batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+# else:
+# batch[k] = torch.tensor([f[k] for f in features])
+# if k in ('pixel_values', 'image_flags'):
+# if isinstance(v, torch.Tensor):
+# batch[k] = torch.concat([f[k] for f in features])
+# elif isinstance(v, np.ndarray):
+# batch[k] = torch.concat(np.stack([f[k] for f in features]))
+# else:
+# batch[k] = torch.concat([f[k] for f in features])
+# return batch
+
+def concat_pad_data_collator(features, pad_id=0):
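+ # Unlike `pad_data_collator`, image-related fields ('pixel_values', 'image_flags',
+ # 'mask_values', 'masks_flags') are concatenated along dim 0 instead of stacked,
+ # since each sample may contribute a different number of image tiles.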
+ first = features[0]
+ batch = {}
+ batch_lens = [len(feat['input_ids']) for feat in features]
+ max_item_length = max(batch_lens)
+ for idx in range(len(features)):
+ feat = features[idx]
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+ temp_input_ids[:len(feat['input_ids'])] = torch.Tensor(feat['input_ids'])
+ feat['input_ids'] = temp_input_ids
+ # temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+ # temp_labels[:feat['labels'].shape[0]] = feat['labels']
+ # feat['labels'] = temp_labels
+ feat['attention_mask'] = feat['input_ids'].ne(pad_id)
+
+ # Special handling for labels.
+ # Ensure that tensor is created with the correct type
+ # (it should be automatically the case, but let's make sure of it.)
+ if 'label' in first and first['label'] is not None:
+ label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
+ dtype = torch.long if isinstance(label, int) else torch.float
+ batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
+ elif 'label_ids' in first and first['label_ids'] is not None:
+ if isinstance(first['label_ids'], torch.Tensor):
+ batch['labels'] = torch.stack([f['label_ids'] for f in features])
+ else:
+ dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
+ batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
+
+ # Handling of all other possible keys.
+ # Again, we will use the first element to figure out which key/values are not None for this model.
+ for k, v in first.items():
+ if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'mask_values', 'masks_flags') and \
+ v is not None and not isinstance(v, str):
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.stack([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = torch.tensor([f[k] for f in features])
+ if k in ('pixel_values', 'image_flags', 'mask_values', 'masks_flags'):
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.concat([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.concat(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = torch.concat([f[k] for f in features])
+ return batch
+
+
+def dpo_concat_pad_data_collator(features, pad_id=0):
+
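+ # DPO batches carry paired sequences: each feature has `chosen_*` and `rejected_*`
+ # fields, and each prefix is padded independently to its own maximum length.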
+ first = features[0]
+ batch = {}
+
+ for prefix in ['chosen_', 'rejected_']:
+ batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features]
+ max_item_length = max(batch_lens)
+ for idx in range(len(features)):
+ feat = features[idx]
+ temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
+ temp_input_ids[:feat[f'{prefix}input_ids'].shape[0]] = feat[f'{prefix}input_ids']
+ feat[f'{prefix}input_ids'] = temp_input_ids
+ temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
+ temp_labels[:feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels']
+ feat[f'{prefix}labels'] = temp_labels
+ feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id)
+
+ # Handling of all other possible keys.
+ # Again, we will use the first element to figure out which key/values are not None for this model.
+ for k, v in first.items():
+ if k not in ('pixel_values', 'image_flags') and \
+ v is not None and not isinstance(v, str):
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.stack([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.tensor(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = torch.tensor([f[k] for f in features])
+ if k in ('pixel_values', 'image_flags'):
+ if isinstance(v, torch.Tensor):
+ batch[k] = torch.concat([f[k] for f in features])
+ elif isinstance(v, np.ndarray):
+ batch[k] = torch.concat(np.stack([f[k] for f in features]))
+ else:
+ batch[k] = torch.concat([f[k] for f in features])
+ return batch
diff --git a/internvl/patch/phi3_packed_training_patch.py b/internvl/patch/phi3_packed_training_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cdb60e85b37006a288e5cc205f8d75774acb092
--- /dev/null
+++ b/internvl/patch/phi3_packed_training_patch.py
@@ -0,0 +1,105 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from internvl.model.phi3.modeling_phi3 import (PHI3_ATTENTION_CLASSES,
+ Phi3FlashAttention2)
+
+
+class Phi3FlashAttention2ForPackedTraining(Phi3FlashAttention2):
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the Flash Attention varlen kernel on a packed batch: several sequences are concatenated along the
+ length dimension of a single batch element, and `attention_mask` carries their cumulative sequence lengths.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ Repurposed to carry the cumulative sequence lengths (`cu_seqlens`) of the packed sequences -
+ shape `(batch_size + 1,)`, dtype `torch.int32`.
+ dropout (`float`):
+ Attention dropout
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
+ query_states = query_states.squeeze(0)
+ key_states = key_states.squeeze(0)
+ value_states = value_states.squeeze(0)
+ cu_seqlens = attention_mask.squeeze(0)
+
+ with torch.no_grad():
+ max_seqlen = max([
+ cu_seqlens[idx+1] - cu_seqlens[idx]
+ for idx in range(cu_seqlens.size(0) - 1)
+ ]).item()
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Decide whether to use SWA or not by layer index.
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
+ use_sliding_windows = False
+
+ if not use_sliding_windows:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ query_states = query_states.unsqueeze(0)
+ key_states = key_states.unsqueeze(0)
+ value_states = value_states.unsqueeze(0)
+ return attn_output
+
+
+def replace_phi3_attention_class():
+ PHI3_ATTENTION_CLASSES['flash_attention_2'] = Phi3FlashAttention2ForPackedTraining
+ print('Replace PHI3_ATTENTION_CLASSES to support packed training!!')
diff --git a/internvl/patch/qwen2_packed_training_patch.py b/internvl/patch/qwen2_packed_training_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b55888f0b2da0571d6c47a9cc7c3906fa918281
--- /dev/null
+++ b/internvl/patch/qwen2_packed_training_patch.py
@@ -0,0 +1,106 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES,
+ Qwen2FlashAttention2)
+
+
+# Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2
+class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2):
+
+ def _flash_attention_forward(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ dropout=0.0,
+ softmax_scale=None,
+ use_sliding_windows=False,
+ ):
+ """
+ Calls the Flash Attention varlen kernel on a packed batch: several sequences are concatenated along the
+ length dimension of a single batch element, and `attention_mask` carries their cumulative sequence lengths.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ Repurposed to carry the cumulative sequence lengths (`cu_seqlens`) of the packed sequences -
+ shape `(batch_size + 1,)`, dtype `torch.int32`.
+ dropout (`float`, *optional*):
+ Attention dropout probability
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_sliding_windows (`bool`, *optional*):
+ Whether to activate sliding window attention.
+ """
+ assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
+ query_states = query_states.squeeze(0)
+ key_states = key_states.squeeze(0)
+ value_states = value_states.squeeze(0)
+ cu_seqlens = attention_mask.squeeze(0)
+
+ with torch.no_grad():
+ max_seqlen = max([
+ cu_seqlens[idx+1] - cu_seqlens[idx]
+ for idx in range(cu_seqlens.size(0) - 1)
+ ]).item()
+
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
+ causal = self.is_causal and query_length != 1
+
+ # Decide whether to use SWA or not by layer index.
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
+ use_sliding_windows = False
+
+ if not use_sliding_windows:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )
+ else:
+ attn_output = flash_attn_varlen_func(
+ q=query_states,
+ k=key_states,
+ v=value_states,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=dropout,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=(self.config.sliding_window, self.config.sliding_window),
+ )
+
+ query_states = query_states.unsqueeze(0)
+ key_states = key_states.unsqueeze(0)
+ value_states = value_states.unsqueeze(0)
+ return attn_output
+
+
+def replace_qwen2_attention_class():
+ QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining
+ print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!')
diff --git a/internvl/patch/train_dataloader_patch.py b/internvl/patch/train_dataloader_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..eade21e51ca8011f379eca5c4a20f41b31e4db58
--- /dev/null
+++ b/internvl/patch/train_dataloader_patch.py
@@ -0,0 +1,53 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import datasets
+import torch
+import transformers
+from torch.utils.data import DataLoader
+from transformers.trainer import is_datasets_available, seed_worker
+
+
+def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+
+ Subclass and override this method if you want to inject some custom behavior.
+ """
+ if self.train_dataset is None:
+ raise ValueError('Trainer: training requires a train_dataset.')
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ train_dataset = self._remove_unused_columns(train_dataset, description='training')
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
+
+ dataloader_params = {
+ 'batch_size': self._train_batch_size,
+ 'collate_fn': data_collator,
+ 'num_workers': self.args.dataloader_num_workers,
+ 'pin_memory': self.args.dataloader_pin_memory,
+ 'persistent_workers': self.args.dataloader_persistent_workers,
+ }
+
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
+ dataloader_params['sampler'] = self._get_train_sampler()
+ dataloader_params['drop_last'] = self.args.dataloader_drop_last
+ dataloader_params['worker_init_fn'] = seed_worker
+
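+ # For packed datasets the DataLoader is returned unwrapped; `accelerator.prepare` is
+ # skipped, presumably because the packed dataset already handles per-rank sharding
+ # itself (an assumption based on the packed-dataset design, not stated in this file).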
+ if self.args.use_packed_ds:
+ return DataLoader(train_dataset, **dataloader_params)
+ return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+
+
+def replace_train_dataloader():
+ transformers.Trainer.get_train_dataloader = get_train_dataloader
+ # print('Replace train dataloader!!')
diff --git a/internvl/patch/train_sampler_patch.py b/internvl/patch/train_sampler_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d616e055052108c246f09ebe24f692bfb2eb8d
--- /dev/null
+++ b/internvl/patch/train_sampler_patch.py
@@ -0,0 +1,125 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+from typing import List, Optional
+
+import torch
+import transformers
+from torch.utils.data import Dataset, Sampler
+from transformers.tokenization_utils_base import BatchEncoding
+from transformers.trainer import (LengthGroupedSampler, RandomSampler,
+ has_length)
+from transformers.trainer_pt_utils import logger
+
+
+# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
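+ # Greedy balancing: each index is appended to the chunk with the smallest total length
+ # so far; a chunk that reaches its index quota is closed by setting its length to
+ # infinity. E.g. indices [0, 1, 2, 3] with lengths [5, 4, 3, 2] and num_chunks=2
+ # yield [[0, 3], [1, 2]].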
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float('inf')
+
+ return chunks
+
+
+# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
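+ # Shuffle globally, cut into megabatches of `world_size * batch_size` indices, sort each
+ # megabatch by length (longest first), then split it evenly across ranks so every rank
+ # sees samples of similar total length in the same step.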
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+# modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ dataset: Optional[Dataset] = None,
+ lengths: Optional[List[int]] = None,
+ model_input_name: Optional[str] = None,
+ generator=None,
+ ):
+ if dataset is None and lengths is None:
+ raise ValueError('One of dataset and lengths must be provided.')
+
+ self.batch_size = batch_size
+ if lengths is None:
+ model_input_name = model_input_name if model_input_name is not None else 'input_ids'
+ if (
+ not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
+ or model_input_name not in dataset[0]
+ ):
+ raise ValueError(
+ 'Can only automatically infer lengths for datasets whose items are dictionaries with an '
+ f"'{model_input_name}' key."
+ )
+ lengths = [len(feature[model_input_name]) for feature in dataset]
+ elif isinstance(lengths, torch.Tensor):
+ logger.info(
+ 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
+ )
+ lengths = lengths.tolist()
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+# patch trainer
+def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+ # Build the sampler.
+ if self.args.group_by_length:
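+ # The train dataset is expected to be a ConcatDataset whose sub-datasets expose a
+ # precomputed `length` list; the lists are concatenated to drive length grouping.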
+ lengths = []
+ for dataset in self.train_dataset.datasets:
+ lengths = lengths + dataset.length
+ model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+ return LengthGroupedSampler(
+ self.args.train_batch_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
+ # self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ )
+ else:
+ return RandomSampler(self.train_dataset)
+
+
+def replace_train_sampler():
+ transformers.Trainer._get_train_sampler = _get_train_sampler
+ # print('Replace train sampler!!')
diff --git a/internvl/train/__init__.py b/internvl/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e207feb58fdcb609895b072fae31371385c522a7
--- /dev/null
+++ b/internvl/train/__init__.py
@@ -0,0 +1,5 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
diff --git a/internvl/train/__pycache__/__init__.cpython-310.pyc b/internvl/train/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0fdfa4701c7b7fb6e7225b3b00f4652e12b7c3aa
Binary files /dev/null and b/internvl/train/__pycache__/__init__.cpython-310.pyc differ
diff --git a/internvl/train/__pycache__/__init__.cpython-39.pyc b/internvl/train/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..98184b3eb624f02fd106965a292e21eac278ed98
Binary files /dev/null and b/internvl/train/__pycache__/__init__.cpython-39.pyc differ
diff --git a/internvl/train/__pycache__/constants.cpython-310.pyc b/internvl/train/__pycache__/constants.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..939dae928a517aa0a11a187ec900646395da3e6a
Binary files /dev/null and b/internvl/train/__pycache__/constants.cpython-310.pyc differ
diff --git a/internvl/train/__pycache__/constants.cpython-39.pyc b/internvl/train/__pycache__/constants.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c6f87104aa41c876676d2a86ba0aaf5d9330c58f
Binary files /dev/null and b/internvl/train/__pycache__/constants.cpython-39.pyc differ
diff --git a/internvl/train/__pycache__/dataset.cpython-310.pyc b/internvl/train/__pycache__/dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b80205294b6c2eb6d5525bb32dd36b6b0d0a3331
Binary files /dev/null and b/internvl/train/__pycache__/dataset.cpython-310.pyc differ
diff --git a/internvl/train/__pycache__/dataset.cpython-39.pyc b/internvl/train/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c7ad06cc1e5372a37aa1c6af4a8101f435a16d53
Binary files /dev/null and b/internvl/train/__pycache__/dataset.cpython-39.pyc differ
diff --git a/internvl/train/__pycache__/dataset_packed.cpython-39.pyc b/internvl/train/__pycache__/dataset_packed.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..24f3f57af4f1b31784a981640cf2a31b82a3d1cd
Binary files /dev/null and b/internvl/train/__pycache__/dataset_packed.cpython-39.pyc differ
diff --git a/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc b/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5150c2b2fb90e37f173558f2e14f16bb9923d82
Binary files /dev/null and b/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc differ
diff --git a/internvl/train/constants.py b/internvl/train/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f6c967b1b5071b140b2e16bbd4d352d9373204
--- /dev/null
+++ b/internvl/train/constants.py
@@ -0,0 +1,21 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+IMG_START_TOKEN = '<img>'
+IMG_END_TOKEN = '</img>'
+QUAD_START_TOKEN = '<quad>'
+QUAD_END_TOKEN = '</quad>'
+REF_START_TOKEN = '['
+REF_END_TOKEN = ']'
+BOX_START_TOKEN = '<box>'
+BOX_END_TOKEN = '</box>'
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
+CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
+SIGLIP_MEAN = (0.5, 0.5, 0.5)
+SIGLIP_STD = (0.5, 0.5, 0.5)
diff --git a/internvl/train/dataset.py b/internvl/train/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..56cfa77518f9fc1ed96de9300d98c932f1a55ffa
--- /dev/null
+++ b/internvl/train/dataset.py
@@ -0,0 +1,921 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import io
+import matplotlib.pyplot as plt
+from transformers.trainer_pt_utils import LabelSmoother
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+import os
+import random
+import re
+from collections import Counter
+from typing import Dict
+
+import cv2
+import imageio
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import transformers
+from decord import VideoReader
+from internvl.conversation import get_conv_template
+from PIL import Image
+from torch.utils.data import ConcatDataset, WeightedRandomSampler
+from torchvision.transforms.functional import InterpolationMode
+
+from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
+ IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
+ SIGLIP_MEAN, SIGLIP_STD)
+
+try:
+ from petrel_client.client import Client
+ from petrel_client.common.config import Config
+except ImportError as E:
+ print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
+import sys
+
+
+def calculate_ngram_repetition(text, n):
+ words = text.split()
+ ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)]
+ ngram_counts = Counter(ngrams)
+ total_ngrams = len(ngrams)
+ repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
+ return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0
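+
+# Worked example (illustrative, not part of the original file):
+#   calculate_ngram_repetition('a b a b a b', 2)
+# yields 5 bigrams [('a','b'), ('b','a'), ('a','b'), ('b','a'), ('a','b')];
+# both distinct bigrams occur more than once, so the ratio is 2 / 5 = 0.4.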
+
+
+def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
+ for conversation in conversations:
+ if conversation['from'] == 'gpt':
+ model_answer = conversation['value']
+ repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
+ if repeat_ratio > repeat_threshold:
+ raise Exception
+
+
+def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
+ if sample in ['rand', 'middle']: # uniform sampling
+ acc_samples = min(num_frames, vlen)
+ # split the video into `acc_samples` intervals, and sample from each interval.
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
+ ranges = []
+ for idx, interv in enumerate(intervals[:-1]):
+ ranges.append((interv, intervals[idx + 1] - 1))
+ if sample == 'rand':
+ try:
+ frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
+ except:
+ frame_indices = np.random.permutation(vlen)[:acc_samples]
+ frame_indices.sort()
+ frame_indices = list(frame_indices)
+ elif fix_start is not None:
+ frame_indices = [x[0] + fix_start for x in ranges]
+ elif sample == 'middle':
+ frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
+ else:
+ raise NotImplementedError
+
+ if len(frame_indices) < num_frames: # padded with last frame
+ padded_frame_indices = [frame_indices[-1]] * num_frames
+ padded_frame_indices[:len(frame_indices)] = frame_indices
+ frame_indices = padded_frame_indices
+ elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
+ output_fps = float(sample[3:])
+ duration = float(vlen) / input_fps
+ delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents
+ frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
+ frame_indices = np.around(frame_seconds * input_fps).astype(int)
+ frame_indices = [e for e in frame_indices if e < vlen]
+ if max_num_frames > 0 and len(frame_indices) > max_num_frames:
+ frame_indices = frame_indices[:max_num_frames]
+ # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
+ else:
+ raise ValueError
+ return frame_indices
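+
+# Worked example (illustrative, not part of the original file):
+#   get_frame_indices(4, 100, sample='middle')
+# splits the 100 frames into intervals [(0, 24), (25, 49), (50, 74), (75, 99)]
+# and returns their midpoints [12, 37, 62, 87]; sample='fps0.5' with
+# input_fps=25 would instead pick roughly one frame every 2 seconds.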
+
+
+def read_frames_gif(
+ video_path, num_frames, sample='rand', fix_start=None,
+ client=None, min_num_frames=4
+):
+ if 's3://' in video_path:
+ video_bytes = client.get(video_path)
+ gif = imageio.get_reader(io.BytesIO(video_bytes))
+ else:
+ gif = imageio.get_reader(video_path)
+ vlen = len(gif)
+
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
+ frame_indices = get_frame_indices(
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
+ )
+ frames = []
+ for index, frame in enumerate(gif):
+ if index in frame_indices:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ frames.append(frame)
+ return frames
+
+
+def read_frames_decord(
+ video_path, num_frames, sample='rand', fix_start=None,
+ client=None, clip=None, min_num_frames=4
+):
+ if 's3://' in video_path:
+ video_bytes = client.get(video_path)
+ video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
+ else:
+ video_reader = VideoReader(video_path, num_threads=1)
+ vlen = len(video_reader)
+ fps = video_reader.get_avg_fps()
+ duration = vlen / float(fps)
+ if clip:
+ start, end = clip
+ duration = end - start
+ vlen = int(duration * fps)
+ start_index = int(start * fps)
+
+ # t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
+
+ frame_indices = get_frame_indices(
+ t_num_frames, vlen, sample=sample, fix_start=fix_start,
+ input_fps=fps
+ )
+ if clip:
+ frame_indices = [f + start_index for f in frame_indices]
+ frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
+ frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
+ return frames
+
+
+def extract_frame_number(filename):
+ # Extract the numeric part from the filename using regular expressions
+ match = re.search(r'_(\d+).jpg$', filename)
+ return int(match.group(1)) if match else -1
+
+
+def sort_frames(frame_paths):
+ # Extract filenames from each path and sort by their numeric part
+ return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
+
+
+def read_frames_folder(
+ video_path, num_frames, sample='rand', fix_start=None,
+ client=None, clip=None, min_num_frames=4
+):
+ if 's3://' in video_path:
+ image_list = sort_frames(client.list(video_path))
+ frames = []
+ for image in image_list:
+ fp = os.path.join(video_path, image)
+ frame = Image.open(io.BytesIO(client.get(fp)))
+ frames.append(frame)
+ else:
+ image_list = sort_frames(list(os.listdir(video_path)))
+ frames = []
+ for image in image_list:
+ fp = os.path.join(video_path, image)
+ frame = Image.open(fp).convert('RGB')
+ frames.append(frame)
+ vlen = len(frames)
+
+ t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
+
+ if vlen > t_num_frames:
+ frame_indices = get_frame_indices(
+ t_num_frames, vlen, sample=sample, fix_start=fix_start
+ )
+ frames = [frames[i] for i in frame_indices]
+ return frames
+
+
+class WeightedConcatDataset(ConcatDataset):
+ def __init__(self, datasets, weights):
+ super().__init__(datasets)
+ self.weights = torch.DoubleTensor(weights)
+ self.total_size = sum(len(d) for d in datasets)
+ self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
+
+ def __iter__(self):
+ return iter(self.sampler)
+
+ def __len__(self):
+ return self.total_size
+
+
+def pil_loader(img_str):
+ buff = io.BytesIO(img_str)
+ img = Image.open(buff)
+ return img.convert('RGB')
+
+
+class TCSLoader(object):
+
+ def __init__(self, conf_path, sc_config_key='sensecore'):
+ print(f'[TCSLoader] config_path: {conf_path}')
+ print('--> before Client(conf_path)')
+ self.client = Client(conf_path)
+ self.sc_config_key = sc_config_key
+ print('--> after Client(conf_path)')
+
+ def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
+ if image_type == 'image':
+ img_value_str = self.client.get(fn)
+ img = pil_loader(img_value_str)
+ return img
+
+ elif image_type == 'video':
+ if fn.endswith('/'):
+ frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
+ client=self.client, sample=sample)
+ elif fn.endswith('.gif'):
+ frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
+ client=self.client, sample=sample)
+ else:
+ frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
+ client=self.client, sample=sample, clip=clip)
+ return frames
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
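+
+# Illustrative example (not part of the original file): a 100x60 image is
+# pasted onto a 100x100 canvas filled with `background_color`, offset to
+# (0, 20) so the original content ends up vertically centered.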
+
+
+def simulate_jpeg_degradation(quality):
+ def jpeg_degrade(img):
+ with io.BytesIO() as output:
+ img.convert('RGB').save(output, format='JPEG', quality=quality)
+ output.seek(0) # Move the reading cursor to the start of the stream
+ img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
+ return img_jpeg
+ return jpeg_degrade
+
+
+# Define the JPEG compression quality range, pre-create all JPEG compression functions
+qualities = list(range(75, 101))
+jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
+
+
+def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
+ if normalize_type == 'imagenet':
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ elif normalize_type == 'clip':
+ MEAN, STD = CLIP_MEAN, CLIP_STD
+ elif normalize_type == 'siglip':
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
+ else:
+ raise NotImplementedError
+ if is_train: # use data augmentation
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ else:
+ if pad2square is False: # now we use this transform function by default
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ else:
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+
+ return transform
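+
+# Usage sketch (illustrative, not part of the original file); the variable
+# names are hypothetical:
+#
+#   train_transform = build_transform(is_train=True, input_size=448,
+#                                     normalize_type='imagenet')
+#   pixel_values = train_transform(Image.open('sample.jpg'))  # (3, 448, 448)
+#
+# Training applies a random JPEG re-compression (quality 75-100) before the
+# resize; evaluation only resizes and normalizes (optionally padding to a
+# square with the mean color first).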
+
+
+def preprocess(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+ conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] + ': '
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(turns):
+ if turn == '':
+ break
+ turn_len = len(tokenizer(turn).input_ids)
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ instruction_len -= 1
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ cur_len += turn_len
+
+ if i != 0 and not tokenizer.legacy:
+ # The legacy and non-legacy modes handle special tokens differently
+ cur_len -= 1
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ logger.info(tokenizer.decode(z))
+ exit()
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
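+
+# Note (added for clarity): the returned `labels` equal `input_ids` except
+# that every non-assistant span (system prompt, user turns, separators and
+# padding) is replaced by IGNORE_TOKEN_ID, so the loss is computed only on
+# the assistant responses.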
+
+
+def preprocess_mpt(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+ conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ turns = conversation.split(conv.sep)
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
+ for conv_idx in range(3, len(turns), 2):
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, turn in enumerate(re_turns):
+ if turn == '':
+ break
+ turn_len = len(tokenizer(turn).input_ids) + 1
+
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ instruction_len = len(tokenizer(parts[0]).input_ids)
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
+ cur_len += turn_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_phi3(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+ conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ tokenizer.padding_side = 'right'
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ # Mask targets. Only compute loss on the assistant outputs.
+ sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
+
+ turns = conversation.split(conv.sep)
+ re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
+ for conv_idx in range(3, len(turns), 2):
+ re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ target[target == endoftext_id] = IGNORE_TOKEN_ID
+
+ for i, turn in enumerate(re_turns):
+ if turn == '':
+ break
+ if i == 0:
+ turn_len = len(tokenizer(turn).input_ids)
+ else:
+ turn_len = len(tokenizer(turn).input_ids) - 1
+ parts = turn.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if i == 0:
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 1
+ else:
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ # Ignore the user instructions
+ target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
+ # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
+ # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
+ # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
+ cur_len += turn_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ print(repr(tokenizer.decode(z)))
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(
+ f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
+ f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
+ )
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_internlm(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ conv = get_conv_template(template_name)
+ roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]['from']] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence['from']]
+ assert role == conv.roles[j % 2], f'{i}'
+ sentence['value'] = sentence['value'].strip()
+ conv.append_message(role, sentence['value'])
+ conversations.append(conv.get_prompt())
+
+ if not text_only:
+ new_conversations = []
+ for conversation in conversations:
+ for i in range(num_image):
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
+ conversation = conversation.replace('<image>', image_tokens, 1)
+ new_conversations.append(conversation)
+ conversations = new_conversations
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ conversations,
+ return_tensors='pt',
+ padding=False if group_by_length or use_packed_ds else 'max_length',
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+ targets = input_ids.clone()
+
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())  # in InternLM, pad_token_id equals eos_token_id
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID  # <s>
+ parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n
+ info = parts[0] + conv.roles[1]
+ temp_len = len(tokenizer(info).input_ids) - 1  # exclude the tokenizer's <s> token
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
+ cur_len = cur_len + temp_len
+
+ for index in range(1, len(parts) - 1):
+ info = parts[index]
+ part1, part2 = info.split(conv.roles[0])
+ temp_len = len(tokenizer(part1).input_ids) - 1
+ cur_len = cur_len + temp_len
+ part = conv.roles[0] + part2 + conv.roles[1]
+ temp_len = len(tokenizer(part).input_ids) - 1
+ target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
+ cur_len = cur_len + temp_len
+ last_info = parts[-1]
+ temp_len = len(tokenizer(last_info).input_ids) - 1
+ cur_len = cur_len + temp_len
+
+ target[cur_len:] = IGNORE_TOKEN_ID
+ if False: # Inspect and check the correctness of masking
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ print(repr(tokenizer.decode(z)))
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_TOKEN_ID
+ print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
+ sys.stdout.flush()
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
+
+
+def preprocess_internvl2_5(
+ template_name,
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ num_image_token_list: list,
+ text_only: bool = False,
+ group_by_length: bool = False,
+ use_packed_ds: bool = False,
+ ds_name: str = None,
+ num_image: int = 1
+) -> Dict:
+ assert len(sources) == 1, 'process only the first conversations'
+ conversations = sources[0]
+
+ if conversations[0]['from'] == 'system':
+ system_prompt = conversations[0]['value']
+ conversations = conversations[1:] # remove system prompt
+ else:
+ conv = get_conv_template(template_name)
+ system_prompt = conv.system_message
+ # system_prompt = None
+
+ if not text_only:
+ new_conversations = []
+ current_image_idx = 0
+ for conversation in conversations:
+ if conversation['from'] == 'human':
+ image_cnt = conversation['value'].count('<image>')
+ for i in range(image_cnt):
+ if current_image_idx == num_image:
+ break
+ image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
+ conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1)
+ current_image_idx += 1
+ new_conversations.append(conversation)
+ conversations = new_conversations
+ assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'
+
+ batches, roles = [], []
+ if system_prompt is not None:
+ batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
+ roles.append('system')
+ for conversation in conversations:
+ if conversation['from'] == 'human':
+ batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
+ roles.append('human')
+ elif conversation['from'] == 'gpt':
+ batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
+ roles.append('gpt')
+ else:
+ raise NotImplementedError
+
+ add_bos_token = getattr(tokenizer, 'add_bos_token', False)
+ if add_bos_token: # for InternLM series
+ batches[0] = tokenizer.bos_token + batches[0]
+
+ # Tokenize conversations
+ input_ids = tokenizer(
+ batches,
+ return_tensors='np',
+ padding=False,
+ max_length=tokenizer.model_max_length,
+ truncation=False,
+ ).input_ids
+
+ if add_bos_token: # for InternLM series
+ input_ids = [item[1:] for item in input_ids]
+
+ final_input_ids, final_targets = [], []
+ ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
+ ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
+ for role, input_id in zip(roles, input_ids):
+ final_input_ids.append(input_id)
+ if role == 'system' or role == 'human':
+ final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID)) # ignore
+ elif role == 'gpt':
+ target = input_id.copy()
+ target[:ignore_len] = IGNORE_TOKEN_ID # ignore loss for `<|im_start|>assistant\n`
+ target[-1:] = IGNORE_TOKEN_ID # ignore loss for `\n`
+ final_targets.append(target)
+ else:
+ raise NotImplementedError
+ input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
+ targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]
+
+ padding = False if group_by_length or use_packed_ds else True
+ if padding:
+ current_length = input_ids.size(0)
+ padding_length = tokenizer.model_max_length - current_length
+ input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
+ targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)
+
+ input_ids = input_ids.unsqueeze(0)
+ targets = targets.unsqueeze(0)
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ attention_mask=input_ids.ne(tokenizer.pad_token_id),
+ )
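+
+# Illustrative example (not part of the original file): a single-image,
+# single-turn sample is rendered roughly as
+#   <|im_start|>system\n{system_prompt}<|im_end|>\n
+#   <|im_start|>user\n<img><IMG_CONTEXT>*N</img>\n{question}<|im_end|>\n
+#   <|im_start|>assistant\n{answer}<|im_end|>\n
+# and only the assistant segment (minus its '<|im_start|>assistant\n' header
+# and trailing '\n') keeps real labels.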
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+ best_ratio_diff = float('inf')
+ best_ratio = (1, 1)
+ area = width * height
+ for ratio in target_ratios:
+ target_aspect_ratio = ratio[0] / ratio[1]
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+ if ratio_diff < best_ratio_diff:
+ best_ratio_diff = ratio_diff
+ best_ratio = ratio
+ elif ratio_diff == best_ratio_diff:
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+ best_ratio = ratio
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
+ return best_ratio
+
+
+def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False):
+ orig_width, orig_height = image.size
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+ resized_img = image.resize((target_width, target_height))
+ processed_images = []
+ for i in range(blocks):
+ box = (
+ (i % (target_width // image_size)) * image_size,
+ (i // (target_width // image_size)) * image_size,
+ ((i % (target_width // image_size)) + 1) * image_size,
+ ((i // (target_width // image_size)) + 1) * image_size
+ )
+ # split the image
+ split_img = resized_img.crop(box)
+ processed_images.append(split_img)
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = image.resize((image_size, image_size))
+ processed_images.append(thumbnail_img)
+ if return_ratio:
+ return processed_images, target_aspect_ratio
+ return processed_images
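+
+# Worked example (illustrative, not part of the original file): a 1000x500
+# image with image_size=448 and max_num=6 has aspect ratio 2.0, which matches
+# the (2, 1) grid; the image is resized to 896x448 and split into two 448x448
+# tiles, and with use_thumbnail=True a third 448x448 thumbnail of the full
+# image is appended.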
+
+
+def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+ # import pdb
+ length, orig_height, orig_width = image.shape
+ aspect_ratio = orig_width / orig_height
+
+ # calculate the existing image aspect ratio
+ target_ratios = set(
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+ i * j <= max_num and i * j >= min_num)
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+ # find the closest aspect ratio to the target
+ target_aspect_ratio = find_closest_aspect_ratio(
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+ # print(target_aspect_ratio)
+ # calculate the target width and height
+ target_width = image_size * target_aspect_ratio[0]
+ target_height = image_size * target_aspect_ratio[1]
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+ # resize the image
+
+ tensor_images = image.unsqueeze(1)  # add a channel dimension (treat masks as single-channel)
+ # pdb.set_trace()
+ resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False) #(1792,1344)
+ resized_images = resized_images > 0
+ # print(resized_images.shape)
+ # crop the mask into tiles, mirroring the PIL-based cropping above
+ processed_images = []
+ for i in range(blocks):
+ top = (i // (target_width // image_size)) * image_size
+ left = (i % (target_width // image_size)) * image_size
+ bottom = top + image_size
+ right = left + image_size
+ # crop via tensor slicing
+ split_img = resized_images[..., top:bottom, left:right]  # the ellipsis keeps the channel dimension
+ processed_images.append(split_img)
+ # plt.imshow(split_img.sum(0).squeeze())
+ # plt.savefig(f'/workdir/guantongkun/12490719/eef5a3b245897c9f4335463fb12fed35/work_dirs/{i}_mask.jpg', dpi=600)
+ # pdb.set_trace()
+ # finally, post-process the list of tiles as needed,
+ # e.g. squeeze out the single-channel dimension as done below
+ processed_images = [img.squeeze(1) for img in processed_images]
+
+ assert len(processed_images) == blocks
+ if use_thumbnail and len(processed_images) != 1:
+ thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False).squeeze(1)
+ thumbnail_img = thumbnail_img > 0
+ # Image.fromarray(thumbnail_img.cpu().numpy().astype(np.uint8))
+ processed_images.append(thumbnail_img)
+ return processed_images
diff --git a/internvl/train/dataset_packed.py b/internvl/train/dataset_packed.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b54ed47524b054ab54d737457626da785723fe7
--- /dev/null
+++ b/internvl/train/dataset_packed.py
@@ -0,0 +1,634 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import bisect
+import copy
+import logging
+from collections import defaultdict
+from typing import List, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import IterableDataset, get_worker_info
+from transformers.trainer_pt_utils import LabelSmoother
+
+from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN
+
+IGNORE_TOKEN_ID = LabelSmoother.ignore_index
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+class PackedDataset(IterableDataset):
+ def __init__(
+ self,
+ tokenizer,
+ data_rank,
+ data_world_size,
+ datasets: List,
+ dataset_weight: List[int] = None,
+ num_images_expected: int = 6,
+ max_packed_tokens: int = 32768,
+ max_buffer_size: int = 100,
+ log_freq: int = 1000000,
+ strict_mode: bool = False,
+ debug_mode: bool = False,
+ replacement: bool = True,
+ allow_overflow: bool = True,
+ allow_empty_data: bool = False,
+ allow_deduplicated_ds_name: bool = False,
+ ):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_rank = data_rank
+ self.data_world_size = data_world_size
+ self.datasets = datasets
+ self.num_images_expected = num_images_expected
+ self.max_buffer_size = max_buffer_size
+ self.log_freq = log_freq
+ self.strict_mode = strict_mode
+ self.debug_mode = debug_mode
+ self.replacement = replacement
+ self.allow_overflow = allow_overflow
+ self.allow_empty_data = allow_empty_data
+
+ self.max_packed_tokens = max_packed_tokens
+
+ self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
+ self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
+ self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)
+
+ assert self.img_start_token_id != self.tokenizer.unk_token_id
+ assert self.img_token_id != self.tokenizer.unk_token_id
+ assert self.img_end_token_id != self.tokenizer.unk_token_id
+
+ if dataset_weight is None:
+ dataset_weight = [1] * len(datasets)
+ self.dataset_type = [d.dataset_type for d in self.datasets]
+
+ self.datasets_orig = datasets
+ self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight]
+
+ self.datasets = [ds for ds in self.datasets_orig]
+ self.dataset_weight = [w for w in self.dataset_weight_orig]
+
+ # lazy init
+ self.worker_id = None
+ self.worker_state_key = None
+ self.dataset_iter_list = None
+ self._state_dict = {
+ 'sample_info': {d.ds_name:0 for d in self.datasets},
+ }
+
+ self.worker_custom_infos = None
+
+ ds_name_list = [d.ds_name for d in self.datasets]
+ if not allow_deduplicated_ds_name:
+ assert len(ds_name_list) == len(set(ds_name_list)), f'duplicated ds_name: {ds_name_list}'
+
+ for ds in self.datasets:
+ if ds.max_num_images > self.num_images_expected:
+ logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
+ ds.max_num_images = num_images_expected
+
+ if ds.max_tokens > self.max_packed_tokens:
+ logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
+ ds.max_tokens = self.max_packed_tokens
+
+ self._state_dict[ds.ds_name] = {}
+
+ if get_rank() == 0:
+ logger.info(
+ f'Loaded dataset to pack: {ds_name_list}, '
+ f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
+ f'{self.replacement=}, {self.allow_overflow=}',
+ )
+
+ temp = []
+ for ds, ds_w in zip(self.datasets, self.dataset_weight):
+ temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%')
+ temp = '\n'.join(temp)
+ logger.info(
+ f'Sampling prob for each dataset:\n{temp}'
+ )
+
+ if self.allow_empty_data:
+ logger.warning('allow_empty_data is enabled, note that empty data may be generated!')
+
+ def load_state_dict(self, state_dict, custom_infos=None):
+
+ self.worker_custom_infos = custom_infos
+
+ self._state_dict.update(state_dict)
+ for ds in self.datasets:
+ if ds.ds_name in self._state_dict:
+ ds.load_state_dict(self._state_dict[ds.ds_name])
+ logger.info(f'{ds.ds_name=} is resumed.')
+ else:
+ logger.warning(f'{ds.ds_name=} is not resumed.')
+
+ def _should_log(self):
+ worker_id = 0 if get_worker_info() is None else get_worker_info().id
+ num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
+
+ worker_id = num_workers * get_rank() + worker_id
+ num_workers = num_workers * get_world_size()
+
+ return worker_id == 0
+
+ def next_data(self, current_dataset_idx):
+ while True:
+ try:
+ current_sample = next(self.dataset_iter_list[current_dataset_idx])
+ break # Exit loop if successful
+ except StopIteration:
+ if self.replacement:
+ # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.')
+ try:
+ self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
+ current_sample = next(self.dataset_iter_list[current_dataset_idx])
+ break
+ except:
+ # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
+ self.datasets.pop(current_dataset_idx)
+ self.dataset_iter_list.pop(current_dataset_idx)
+ self.dataset_weight.pop(current_dataset_idx)
+
+ if len(self.datasets) == 0:
+ raise StopIteration
+ current_dataset_idx = np.random.choice(len(self.datasets))
+ else:
+ # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
+ self.datasets.pop(current_dataset_idx)
+ self.dataset_iter_list.pop(current_dataset_idx)
+ self.dataset_weight.pop(current_dataset_idx)
+
+ if len(self.datasets) == 0:
+ raise StopIteration
+ current_dataset_idx = np.random.choice(len(self.datasets))
+ except:
+ logger.error('Unexpected error!')
+ if len(self.datasets) == 0:
+ raise StopIteration
+ current_dataset_idx = np.random.choice(len(self.datasets))
+
+ current_ds_name = self.datasets[current_dataset_idx].ds_name
+ current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx
+
+ if self.worker_state_key not in self._state_dict[current_ds_name]:
+ self._state_dict[current_ds_name][self.worker_state_key] = {}
+
+ meta_info = current_sample.pop('meta_info', {})
+ self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
+ self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1
+ return current_sample
+
+ def find_buffer(self, buffer_list, new_sample):
+ # NOTE: use `bisect` to search might be faster
+
+ find = False
+ find_idx = -1
+ num_images_current = new_sample['pixel_values'].size(0)
+ for buffer_idx, buffer in enumerate(buffer_list):
+ num_images_buffer = buffer['pixel_values'].size(0)
+ if num_images_buffer + num_images_current <= self.num_images_expected:
+ num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0)
+
+ if num_merged_tokens <= self.max_packed_tokens:
+ find = True
+ find_idx = buffer_idx
+ break
+
+ if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
+ find = True
+ find_idx = buffer_idx
+
+ if find:
+ return buffer_list.pop(find_idx)
+ return None
+
+ def update_buffer(self, buffer, new_sample):
+ if buffer is None:
+ new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
+ return new_sample
+
+ new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()
+
+ assert buffer.keys() == new_sample.keys()
+ for k in buffer:
+ buffer[k] = torch.cat([buffer[k], new_sample[k]])
+ return buffer
+
+ @staticmethod
+ def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
+ num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
+ num_tokens = sample_to_check['labels'].numel()
+ return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio
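+
+ # Illustrative example (not part of the original file): for a packed sample
+ # of 4096 tokens, more than 4096 / 256 = 16 tokens must carry a label other
+ # than IGNORE_TOKEN_ID for check_valid to return True; near-fully-ignored
+ # samples are dropped by split_buffer.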
+
+ @staticmethod
+ def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
+ if buffer['input_ids'].size(0) <= max_tokens:
+ return [buffer]
+
+ def _image_is_splitted(input_ids, cut_idx):
+ is_image_start = input_ids[cut_idx].item() == img_start_token_id
+ is_image_token = input_ids[cut_idx].item() == img_token_id
+ is_image_end = input_ids[cut_idx].item() == img_end_token_id
+ return is_image_start or is_image_token or is_image_end
+
+ def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
+ assert (right_idx is None) == (right_img_idx is None)
+
+ left_sample = {}
+ right_sample = {} if right_idx is not None else None
+ for k in sample_to_split:
+ if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
+ left_sample[k] = sample_to_split[k][:left_idx]
+ if right_sample is not None:
+ right_sample[k] = sample_to_split[k][right_idx:]
+ elif k in ['pixel_values', 'image_flags']:
+ left_sample[k] = sample_to_split[k][:left_img_idx]
+ if right_sample is not None:
+ right_sample[k] = sample_to_split[k][right_img_idx:]
+ else:
+ raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
+ return left_sample, right_sample
+
+ splitted_buffer = []
+ while buffer['input_ids'].size(0) > max_tokens:
+ img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
+ img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
+ assert len(img_start_idx_list) == len(img_end_idx_list)
+
+ if _image_is_splitted(buffer['input_ids'], max_tokens):
+ cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
+ if buffer['input_ids'][max_tokens] == img_start_token_id:
+ assert max_tokens == img_start_idx_list[cut_idx]
+ cut_left_idx = img_start_idx_list[cut_idx]
+ cut_left_img_idx = cut_idx
+ else:
+ cut_left_idx = img_start_idx_list[cut_idx - 1]
+ cut_left_img_idx = cut_idx - 1
+ cut_right_idx = cut_left_idx
+ cut_right_img_idx = cut_left_img_idx
+ else:
+ cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
+ if cut_img_idx < len(img_start_idx_list):
+ cut_right_idx = img_start_idx_list[cut_img_idx]
+ cut_right_img_idx = cut_img_idx
+ else:
+ cut_right_idx = None
+ cut_right_img_idx = None
+
+ cut_left_idx = max_tokens
+ cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)
+
+ left, right = _split(
+ sample_to_split=buffer,
+ left_idx=cut_left_idx,
+ left_img_idx=cut_left_img_idx,
+ right_idx=cut_right_idx,
+ right_img_idx=cut_right_img_idx,
+ )
+
+ assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
+ if right is not None:
+ assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)
+
+ if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
+ splitted_buffer.append(left)
+
+ if right is None or right['pixel_values'].size(0) == 0:
+ break
+
+ buffer = right
+ if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
+ splitted_buffer.append(buffer)
+ break
+
+ logger.debug(
+ f'split a sample into {len(splitted_buffer)} samples, '
+ f'current max_tokens={max_tokens}'
+ )
+ return splitted_buffer
+
+ def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
+ # NOTE: in-place operation
+
+ splitted_buffer = PackedDataset.split_buffer(
+ buffer=buffer,
+ max_tokens=self.max_packed_tokens,
+ img_start_token_id=self.img_start_token_id,
+ img_token_id=self.img_token_id,
+ img_end_token_id=self.img_end_token_id,
+ )
+
+ for each_buffer in splitted_buffer:
+ if each_buffer['pixel_values'].size(0) > self.num_images_expected:
+ logger.error(
+ f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
+ f'which exceeds {self.num_images_expected}'
+ )
+ continue
+
+ if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
+ assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
+ buffer_max_len_list.append(each_buffer)
+ continue
+
+ find_idx = len(buffer_list)
+ num_images_new_sample = each_buffer['pixel_values'].size(0)
+ for buffer_idx in range(len(buffer_list)):
+ if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample:
+ find_idx = buffer_idx
+ break
+ buffer_list.insert(find_idx, each_buffer)
+
+ for i in range(1, len(buffer_list)):
+ assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)
+
+ return buffer_list, buffer_max_len_list
+
+ def pad_buffer(self, buffer):
+ if buffer['pixel_values'].size(0) == self.num_images_expected:
+ return buffer
+
+ num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0)
+ pad_images = torch.stack([
+ torch.zeros_like(buffer['pixel_values'][0])
+ for _ in range(num_pad_images)
+ ])
+ pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long)
+
+ buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images])
+ buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])
+
+ return buffer
+
+ def postprocess_buffer(self, buffer, custom_infos=None):
+ buffer['worker_state_key'] = self.worker_state_key
+ buffer['worker_state_dict'] = self._state_dict
+ if custom_infos is not None:
+ buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
+ return buffer
+
+ def print_log(self, iter_idx, buffer_list):
+ if iter_idx % self.log_freq != 0:
+ return
+
+ if self._should_log():
+ logger.info(
+ f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
+ )
+
+ def __iter__(self):
+ iter_idx = 0
+ buffer_list = []
+ buffer_max_len_list = []
+
+ if self._should_log():
+ logger.info(f'Begin to iter, {len(buffer_list)=}')
+
+ worker_id = 0 if get_worker_info() is None else get_worker_info().id
+ num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers
+
+ worker_id = num_workers * self.data_rank + worker_id
+ num_workers = num_workers * self.data_world_size
+
+ rng = np.random.default_rng(seed=worker_id)
+
+ # reset states of each dataset
+ self.worker_id = worker_id
+ self.worker_state_key = f'work_state_{self.worker_id}'
+ self.datasets = [d for d in self.datasets_orig]
+ self.dataset_weight = [w for w in self.dataset_weight_orig]
+ self.dataset_iter_list = [iter(d) for d in self.datasets]
+
+ for ds in self.datasets:
+ # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)):
+ ds.worker_id = worker_id
+ ds.worker_state_key = f'work_state_{self.worker_id}'
+ ds.num_workers = num_workers
+ if self._should_log() and worker_id == 0:
+ logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')
+
+ if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
+ custom_infos = self.worker_custom_infos[self.worker_state_key]
+ # buffer list
+ if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
+ buffer_list = custom_infos['buffer_list']
+ if self._should_log() and worker_id == 0:
+ logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
+ # other infos
+
+ # reset
+ self.worker_custom_infos = None
+
+ logger.debug(
+ f'{self.__class__.__name__} Rank {self.data_rank} '
+ f'Worker {worker_id} begin to load data'
+ )
+
+ while True:
+ self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight]
+ current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)
+
+ try:
+ current_sample = self.next_data(current_dataset_idx)
+ except:
+ logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
+ while len(buffer_list) > 0:
+ if self.strict_mode:
+ yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
+ else:
+ yield self.postprocess_buffer(buffer_list.pop(0))
+ logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
+ return
+
+ buffer = self.find_buffer(buffer_list, current_sample)
+ buffer = self.update_buffer(buffer, current_sample)
+ buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)
+
+ while len(buffer_max_len_list) > 0:
+ if buffer_max_len_list[0]['pixel_values'].size(0) != self.max_packed_tokens:
+ logger.debug(
+ f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
+ f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
+ )
+ if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected:
+ # buffer_max_len_list.pop(0)
+ yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list})
+ else:
+ yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})
+
+ while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
+ logger.error(
+ f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
+ f'is larger than num_images_expected({self.num_images_expected})'
+ )
+ buffer_list.pop(0)
+
+ while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
+ if self.debug_mode:
+ debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
+ while True:
+ yield debug_data.copy()
+
+ yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
+
+ while len(buffer_list) > self.max_buffer_size:
+ logger.debug(
+ f'Failed to pack data to exactly {self.num_images_expected} images, '
+ f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
+ )
+ if self.strict_mode:
+ yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list})
+ else:
+ yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
+
+ self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
+ iter_idx += 1
+
+ @staticmethod
+ def get_cu_seqlens_and_indexes(
+ data_index: torch.LongTensor, # (seq_len,)
+ input_ids: torch.LongTensor, # (seq_len,)
+ labels: torch.LongTensor, # (seq_len,)
+ len2weight: callable,
+ ):
+ indexes = []
+ cu_seqlens = [0]
+ loss_weight = []
+
+ start = data_index.min()
+ end = data_index.max() + 1
+ for i in range(start, end):
+ num_tokens = (data_index == i).sum().item()
+ indexes.extend(list(range(num_tokens)))
+ cu_seqlens.append(cu_seqlens[-1] + num_tokens)
+ assert num_tokens > 0
+
+ curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
+ assert (curr_data_index == i).all(), data_index
+
+ curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
+ num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
+ loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)
+
+ assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'
+
+ loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
+ return cu_seqlens, indexes, loss_weight
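+
+ # Worked example (illustrative, not part of the original file): if data_index
+ # is [0]*10 + [1]*6 (two packed sub-samples of 10 and 6 tokens), this returns
+ # cu_seqlens = [0, 10, 16] and indexes = [0..9, 0..5]; packed_collate_fn then
+ # appends max_item_length as a final boundary and extends indexes over the
+ # padding region.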
+
+
+WARNING_CNT = defaultdict(int)
+
+
+def packed_collate_fn(
+ features,
+ data_collator,
+ len2weight: callable,
+ max_item_length: int,
+ micro_num: int = 1,
+ loss_reduction_all_gather: bool = False,
+ pad_id: int = 0,
+):
+ if not isinstance(features, list):
+ features = [features]
+
+ if len(features) > micro_num:
+ raise NotImplementedError(f'{len(features)=} > {micro_num=}')
+
+ if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
+ logger.warning(
+ f'{len(features)=} < {micro_num=}, '
+ f'the features will be padded to satisfy micro_num requirement'
+ )
+ WARNING_CNT['micro_num_warning'] += 1
+
+ # ensure that the len(features) is equal to the required micro_num
+ num_features = len(features)
+ while len(features) < micro_num:
+ features.append(copy.deepcopy(features[0]))
+ features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)
+
+ indexes = []
+ cu_seqlens = []
+ cu_num_images_list = [0]
+
+ worker_state_key_list = []
+ worker_state_dict_list = []
+ worker_state_custom_infos_list = []
+
+ batch_lens = [feat['input_ids'].shape for feat in features]
+ max_item_length = max_item_length or max(batch_lens)[0]
+
+ num_samples = 0
+ num_padding_tokens = 0
+ for feat_idx, feat in enumerate(features):
+ data_index = feat.pop('data_index')
+ curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
+ data_index=data_index,
+ input_ids=feat['input_ids'],
+ labels=feat['labels'],
+ len2weight=len2weight,
+ )
+
+ feat['loss_weight'] = curr_loss_weight
+
+ if feat_idx < num_features:
+ num_samples += len(curr_cu_seqlens) - 1
+
+ if curr_cu_seqlens[-1] < max_item_length:
+ curr_cu_seqlens.append(max_item_length)
+ curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))
+
+ indexes.append(torch.tensor(curr_indexes, dtype=torch.long))
+ cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))
+
+ worker_state_key_list.append(feat.pop('worker_state_key'))
+ worker_state_dict_list.append(feat.pop('worker_state_dict'))
+ worker_state_custom_infos_list.append(feat.pop('custom_infos', None))
+
+ num_padding_tokens += (max_item_length - feat['input_ids'].size(0))
+ cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0))
+
+ batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
+ # convert it to list in case it is converted into bf16
+ batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
+ batch['attention_mask'] = torch.stack(cu_seqlens)
+ batch['loss_reduction_all_gather'] = loss_reduction_all_gather
+ batch['statistics'] = torch.tensor(
+ [
+ num_samples,
+ num_padding_tokens,
+ batch['image_flags'].numel() - batch['image_flags'].sum().item(),
+ ],
+ dtype=torch.long,
+ )
+ batch.pop('type_ids')
+ return batch
diff --git a/internvl/train/trainer_dpo.py b/internvl/train/trainer_dpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7a3ed624a43900f902a9dd48b7182b8a21bf6d2
--- /dev/null
+++ b/internvl/train/trainer_dpo.py
@@ -0,0 +1,287 @@
+# --------------------------------------------------------
+# InternVL
+# Copyright (c) 2024 OpenGVLab
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.utils.data import ConcatDataset
+from trl import DPOTrainer
+from trl.trainer.utils import RunningMoments, pad_to_length
+
+
+def _map(self, *args, **kwargs):
+ return self
+
+
+ConcatDataset.map = _map
+
+
+class MultimodalDPOTrainer(DPOTrainer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ if self.loss_type != 'bco_pair' and 'bco_pair' in self.loss_type:
+ self.running = RunningMoments(self.accelerator)
+
+ @staticmethod
+ def concatenated_inputs(
+ batch: Dict[str, Union[List, torch.LongTensor]],
+ is_encoder_decoder: bool = False,
+ is_vision_model: bool = False,
+ label_pad_token_id: int = -100,
+ padding_value: int = 0,
+ device: Optional[torch.device] = None,
+ ) -> Dict[str, torch.LongTensor]:
+ """Concatenate the chosen and rejected inputs into a single tensor.
+
+ Args:
+ batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
+ is_encoder_decoder: Whether the model is an encoder-decoder model.
+ label_pad_token_id: The label pad token id.
+ padding_value: The padding value to use for the concatenated inputs_ids.
+ device: The device for the concatenated inputs.
+
+ Returns:
+ A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
+ """
+ concatenated_batch = {}
+
+ if is_encoder_decoder:
+ max_length = max(batch['chosen_labels'].shape[1], batch['rejected_labels'].shape[1])
+ else:
+ max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1])
+
+ for k in batch:
+ if k.startswith('chosen') and isinstance(batch[k], torch.Tensor):
+ if 'labels' in k or is_encoder_decoder:
+ pad_value = label_pad_token_id
+ elif k.endswith('_input_ids'):
+ pad_value = padding_value
+ elif k.endswith('_attention_mask'):
+ pad_value = 0
+ concatenated_key = k.replace('chosen', 'concatenated')
+ concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
+ for k in batch:
+ if k.startswith('rejected') and isinstance(batch[k], torch.Tensor):
+ if 'labels' in k or is_encoder_decoder:
+ pad_value = label_pad_token_id
+ elif k.endswith('_input_ids'):
+ pad_value = padding_value
+ elif k.endswith('_attention_mask'):
+ pad_value = 0
+ concatenated_key = k.replace('rejected', 'concatenated')
+ concatenated_batch[concatenated_key] = torch.cat(
+ (
+ concatenated_batch[concatenated_key],
+ pad_to_length(batch[k], max_length, pad_value=pad_value),
+ ),
+ dim=0,
+ ).to(device=device)
+
+ if is_encoder_decoder:
+ concatenated_batch['concatenated_input_ids'] = batch['prompt_input_ids'].repeat(2, 1).to(device=device)
+ concatenated_batch['concatenated_attention_mask'] = (
+ batch['prompt_attention_mask'].repeat(2, 1).to(device=device)
+ )
+
+ if 'pixel_values' in batch:
+ concatenated_batch['pixel_values'] = batch['pixel_values'].repeat(2, 1, 1, 1)
+ concatenated_batch['image_flags'] = batch['image_flags'].repeat(2)
+
+ return concatenated_batch
+
+ def concatenated_forward(
+ self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+ """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
+
+ We do this to avoid doing two forward passes, because it's faster for FSDP.
+ """
+ concatenated_batch = self.concatenated_inputs(
+ batch,
+ is_encoder_decoder=self.is_encoder_decoder,
+ is_vision_model=self.is_vision_model,
+ label_pad_token_id=self.label_pad_token_id,
+ padding_value=self.padding_value,
+ device=self.accelerator.device,
+ )
+ len_chosen = batch['chosen_labels'].shape[0]
+
+ model_kwargs = {}
+
+ if self.is_encoder_decoder:
+ model_kwargs['labels'] = concatenated_batch['concatenated_labels']
+ model_kwargs['decoder_input_ids'] = concatenated_batch.pop('concatenated_decoder_input_ids', None)
+
+ if self.is_vision_model:
+ model_kwargs['pixel_values'] = concatenated_batch['pixel_values']
+ model_kwargs['pixel_attention_mask'] = concatenated_batch['pixel_attention_mask']
+
+ if self.aux_loss_enabled:
+ model_kwargs['output_router_logits'] = True
+
+ outputs = model(
+ input_ids=concatenated_batch['concatenated_input_ids'],
+ attention_mask=concatenated_batch['concatenated_attention_mask'],
+ pixel_values=concatenated_batch['pixel_values'],
+ image_flags=concatenated_batch['image_flags'],
+ use_cache=False,
+ **model_kwargs,
+ )
+ all_logits = outputs.logits
+
+ all_logps, size_completion = self.get_batch_logps(
+ all_logits,
+ concatenated_batch['concatenated_labels'],
+ # average_log_prob=self.loss_type == "ipo",
+ is_encoder_decoder=self.is_encoder_decoder,
+ label_pad_token_id=self.label_pad_token_id,
+ )
+
+ def cross_entropy_loss(logits, labels):
+ if not self.is_encoder_decoder:
+ # Shift so that tokens < n predict n
+ logits = logits[..., :-1, :].contiguous()
+ labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ logits = logits.view(-1, logits.shape[-1])
+ labels = labels.view(-1)
+ # Enable model parallelism
+ labels = labels.to(logits.device)
+ loss = loss_fct(logits, labels)
+ return loss
+
+ labels = concatenated_batch['concatenated_labels'].clone()
+ nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
+
+ if self.loss_type == 'ipo':
+ all_logps = all_logps / size_completion
+
+ chosen_logps = all_logps[:len_chosen]
+ rejected_logps = all_logps[len_chosen:]
+
+ chosen_logits = all_logits[:len_chosen]
+ rejected_logits = all_logits[len_chosen:]
+
+ if self.aux_loss_enabled:
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss)
+
+ return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
+
+ def _prepare_deepspeed(self, model):
+ deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+ config_kwargs = deepspeed_plugin.deepspeed_config
+ if config_kwargs['zero_optimization']['stage'] == 3:
+ print('Enable DPOTrainer._prepare_deepspeed')
+ return super()._prepare_deepspeed(model)
+
+ print('Disable DPOTrainer._prepare_deepspeed')
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ model = model.to(self.accelerator.device)
+ return model
+
+ def get_batch_loss_metrics(
+ self,
+ model,
+ batch: Dict[str, Union[List, torch.LongTensor]],
+ train_eval: Literal['train', 'eval'] = 'train',
+ ):
+ """Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
+ metrics = {}
+
+ forward_output = self.concatenated_forward(model, batch)
+ (
+ policy_chosen_logps,
+ policy_rejected_logps,
+ policy_chosen_logits,
+ policy_rejected_logits,
+ policy_nll_loss,
+ ) = forward_output[:5]
+ if self.aux_loss_enabled:
+ aux_loss = forward_output[5]
+
+ # if reference_chosen_logps and reference_rejected_logps are in the batch, use them; otherwise compute them with the reference model
+ if (
+ 'reference_chosen_logps' in batch
+ and 'reference_rejected_logps' in batch
+ and self.args.rpo_alpha is not None
+ ):
+ reference_chosen_logps = batch['reference_chosen_logps']
+ reference_rejected_logps = batch['reference_rejected_logps']
+ else:
+ with torch.no_grad():
+ if self.ref_model is None:
+ with self.null_ref_context():
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ _,
+ ) = self.concatenated_forward(self.model, batch)
+ else:
+ (
+ reference_chosen_logps,
+ reference_rejected_logps,
+ _,
+ _,
+ _,
+ ) = self.concatenated_forward(self.ref_model, batch)
+
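+ # Combined losses: a loss_type such as 'sigmoid,bco_pair' (illustrative) is split on ',' and
+ # each term is weighted by `args.<name>_loss_weight`, which is assumed to exist on the args.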
+ if ',' in self.loss_type:
+ loss_type = self.loss_type
+ loss_type_list = loss_type.split(',')
+
+ losses, chosen_rewards, rejected_rewards = 0, 0, 0
+ for curr_type in loss_type_list:
+ self.loss_type = curr_type
+ curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ )
+ curr_weight = getattr(self.args, f'{curr_type}_loss_weight')
+ losses = losses + curr_losses * curr_weight
+ chosen_rewards = chosen_rewards + curr_chosen_rewards * curr_weight
+ rejected_rewards = rejected_rewards + curr_rejected_rewards * curr_weight
+
+ self.loss_type = loss_type
+ else:
+ losses, chosen_rewards, rejected_rewards = self.dpo_loss(
+ policy_chosen_logps,
+ policy_rejected_logps,
+ reference_chosen_logps,
+ reference_rejected_logps,
+ )
+
+ reward_accuracies = (chosen_rewards > rejected_rewards).float()
+
+ if self.args.rpo_alpha is not None:
+ # losses = losses * self.args.rpo_alpha + policy_nll_loss
+ losses = losses + policy_nll_loss * self.args.rpo_alpha
+
+ prefix = 'eval_' if train_eval == 'eval' else ''
+ metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu()
+ metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu()
+ metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu()
+ metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu()
+ metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu()
+ metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu()
+ metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu()
+ metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu()
+ if self.args.rpo_alpha is not None:
+ metrics[f'{prefix}nll_loss'] = policy_nll_loss.detach().mean().cpu()
+
+ if self.aux_loss_enabled:
+ return losses.mean() + getattr(model.config, 'router_aux_loss_coef', 0.0) * aux_loss, metrics
+
+ return losses.mean(), metrics
diff --git a/internvl/train/trainer_monkey_patch.py b/internvl/train/trainer_monkey_patch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c501c27fc17404b9ba992cf578f7a475e3abe3
--- /dev/null
+++ b/internvl/train/trainer_monkey_patch.py
@@ -0,0 +1,246 @@
+import json
+import os
+
+import torch
+import torch.nn as nn
+import transformers
+from transformers import Trainer, logging
+from transformers.trainer import is_sagemaker_mp_enabled
+
+logger = logging.get_logger(__name__)
+
+
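+# Maps a parameter name to a pseudo layer id used for layer-wise lr decay:
+# embeddings and query tokens -> 0, transformer block k -> k + 1, heads/projectors -> the last layer.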
+def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
+ if var_name.startswith('internvl.'):
+ var_name = var_name[len('internvl.'):]
+ if var_name in ('query_tokens', 'logit_scale',):
+ return 0
+ if var_name.startswith('clip_projector.'):
+ return vit_num_max_layer
+ if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
+ var_name == 'text_projection':
+ return llama_num_max_layer
+ if var_name.startswith('vision_model.'):
+ if 'embeddings.' in var_name:
+ return 0
+ if 'layers.' in var_name:
+ var_name = var_name.split('layers.')[-1]
+ layer_id = int(var_name.split('.')[0])
+ return layer_id + 1
+ if var_name.startswith('qllama.'):
+ if 'embed_tokens' in var_name:
+ return 0
+ if 'layers.' in var_name:
+ var_name = var_name.split('layers.')[-1]
+ layer_id = int(var_name.split('.')[0])
+ return layer_id + 1
+ else:
+ return llama_num_max_layer
+ return 0
+
+
+def param_classification(name):
+ if name.startswith('internvl.'):
+ name = name[len('internvl.'):]
+ if name in ['query_tokens', 'text_projection', 'logit_scale']:
+ return 'qllama'
+ elif name.startswith('vision_model.'):
+ return 'vit'
+ elif name.startswith('qllama.'):
+ return 'qllama'
+ elif name.startswith('clip_projector.'):
+ return 'vit'
+ elif name.startswith('clip_projector2.'):
+ return 'qllama'
+ elif name.startswith('itm_head.'):
+ return 'qllama'
+ else:
+ return 'other'
+
+
+def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ # import pdb; pdb.set_trace()
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+ parameter_groups = {}
+ try: # for stage2 model
+ vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
+ qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
+ except AttributeError: # for stage3 model
+ vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
+ qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
+ print('vit_num_layers:', vit_num_layers)
+ print('qllama_num_layers:', qllama_num_layers)
+
+ vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
+ qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
+ qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
+ print('vit_layer_decay_rate:', vit_layer_decay_rate)
+ print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
+ print('qllama_lr_scale:', qllama_lr_scale)
+
+ for name, param in opt_model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = self.args.weight_decay
+
+ cls = param_classification(name)
+ layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
+ group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
+ if group_name not in parameter_groups:
+ if cls == 'vit':
+ scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
+ elif cls == 'qllama':
+ scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
+ scale = scale * qllama_lr_scale
+ else:
+ scale = 1.0
+ scale = min(1.0, scale)
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.args.learning_rate,
+ }
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+
+ rank = torch.distributed.get_rank()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ print('Param groups = %s' % json.dumps(to_display, indent=2))
+
+ optimizer_grouped_parameters = list(parameter_groups.values())
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == 'Adam8bit':
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
+ manager.register_module_override(module, 'weight', {'optim_bits': 32})
+ logger.debug(f'bitsandbytes: will optimize {module} in fp32')
+ logger.info(f'skipped: {skipped / 2 ** 20}M params')
+
+ if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+ return self.optimizer
+
+def create_optimizer_custom(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ # import pdb; pdb.set_trace()
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+ parameter_groups = {}
+
+ for name, param in opt_model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = self.args.weight_decay
+
+ if 'ocr_mlp' in name or 'upsample' in name:
+ group_name = '%s_%s' % ('modify', group_name)
+ elif 'vision_model' in name:
+ group_name = '%s_%s' % ('vit', group_name)
+ else:
+ group_name = '%s_%s' % ('base', group_name)
+
+ if group_name not in parameter_groups:
+ if 'ocr_mlp' in name or 'upsample' in name:
+ scale = 1.0
+ elif 'vision_model' in name:
+ scale = 0.05
+ else:
+ scale = 1.0
+
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.args.learning_rate,
+ }
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+
+ rank = torch.distributed.get_rank()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay'],
+ }
+ print('Param groups = %s' % json.dumps(to_display, indent=2))
+
+
+ optimizer_grouped_parameters = list(parameter_groups.values())
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == 'Adam8bit':
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
+ manager.register_module_override(module, 'weight', {'optim_bits': 32})
+ logger.debug(f'bitsandbytes: will optimize {module} in fp32')
+ logger.info(f'skipped: {skipped / 2 ** 20}M params')
+
+ if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+ return self.optimizer
+
+
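+# Intended usage (sketch): call replace_create_optimizer() once, before the Trainer is
+# constructed, so that Trainer.create_optimizer resolves to create_optimizer_custom above.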
+def replace_create_optimizer():
+ print('Replace original create_optimizer with custom create_optimizer')
+ # transformers.Trainer.create_optimizer = create_optimizer
+ transformers.Trainer.create_optimizer = create_optimizer_custom
diff --git a/resnet50/__init__.py b/resnet50/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5455d84f91037c842c9184e9ca64720aadf6dcd2
--- /dev/null
+++ b/resnet50/__init__.py
@@ -0,0 +1,5 @@
+from .model import build
+
+
+def build_model(args):
+ return build(args)
\ No newline at end of file
diff --git a/resnet50/__pycache__/__init__.cpython-310.pyc b/resnet50/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..132654c9f4bcd31af1a0213144d4c708e241f413
Binary files /dev/null and b/resnet50/__pycache__/__init__.cpython-310.pyc differ
diff --git a/resnet50/__pycache__/backbone.cpython-310.pyc b/resnet50/__pycache__/backbone.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6039b96ba94844ca55c948363a0fae273c108aec
Binary files /dev/null and b/resnet50/__pycache__/backbone.cpython-310.pyc differ
diff --git a/resnet50/__pycache__/model.cpython-310.pyc b/resnet50/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..186fec8bfdb6980609fe1a8bea27c0d883834954
Binary files /dev/null and b/resnet50/__pycache__/model.cpython-310.pyc differ
diff --git a/resnet50/backbone.py b/resnet50/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..63363e6f1c3db58644ae8120d62b5a9a5585ae27
--- /dev/null
+++ b/resnet50/backbone.py
@@ -0,0 +1,85 @@
+from collections import OrderedDict
+
+import os
+import torch
+import torch.nn.functional as F
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from typing import Dict, List
+
+class FrozenBatchNorm2d(torch.nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt,
+ without which any other models than torchvision.models.resnet[18,34,50,101]
+ produce nans.
+ """
+
+ def __init__(self, n):
+ super(FrozenBatchNorm2d, self).__init__()
+ self.register_buffer("weight", torch.ones(n))
+ self.register_buffer("bias", torch.zeros(n))
+ self.register_buffer("running_mean", torch.zeros(n))
+ self.register_buffer("running_var", torch.ones(n))
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ num_batches_tracked_key = prefix + 'num_batches_tracked'
+ if num_batches_tracked_key in state_dict:
+ del state_dict[num_batches_tracked_key]
+
+ super(FrozenBatchNorm2d, self)._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+
+ def forward(self, x):
+ # move reshapes to the beginning
+ # to make it fuser-friendly
+ w = self.weight.reshape(1, -1, 1, 1)
+ b = self.bias.reshape(1, -1, 1, 1)
+ rv = self.running_var.reshape(1, -1, 1, 1)
+ rm = self.running_mean.reshape(1, -1, 1, 1)
+ eps = 1e-5
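+ # Fold the frozen statistics into a single affine transform:
+ # y = (x - running_mean) / sqrt(running_var + eps) * weight + bias = x * scale + bias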
+ scale = w * (rv + eps).rsqrt()
+ bias = b - rm * scale
+ return x * scale + bias
+
+
+class BackboneBase(nn.Module):
+
+ def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
+ super().__init__()
+ for name, parameter in backbone.named_parameters():
+ if not train_backbone or ('layer2' not in name and 'layer3' not in name and 'layer4' not in name):
+ parameter.requires_grad_(False)
+ if return_interm_layers:
+ return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
+ else:
+ return_layers = {'layer4': "0"}
+ self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+ self.num_channels = num_channels
+
+ def forward(self, tensor_list):
+ xs = self.body(tensor_list)
+ return xs
+
+class Backbone(BackboneBase):
+ """ResNet backbone with frozen BatchNorm."""
+ def __init__(self, name: str,
+ train_backbone: bool,
+ return_interm_layers=False,
+ dilation=False):
+ backbone = getattr(torchvision.models, name)(
+ replace_stride_with_dilation=[False, False, dilation],
+ pretrained=False, norm_layer=nn.BatchNorm2d)
+
+ num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
+
+ super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
+
+
+def build_backbone(args):
+ backbone = Backbone('resnet50', train_backbone=True)
+ return backbone
diff --git a/resnet50/model.py b/resnet50/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4309fcaf6290131df4b35e1deaf42760c49fe90b
--- /dev/null
+++ b/resnet50/model.py
@@ -0,0 +1,164 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from .backbone import build_backbone
+import pdb
+import numpy as np
+from typing import Optional
+
+class TokenOCR(nn.Module):
+ def __init__(self, backbone):
+ """ Initializes the model.
+ Parameters:
+ backbone: torch module of the backbone to be used. See backbone.py
+ transformer: torch module of the transformer architecture. See transformer.py
+ num_classes: number of object classes
+
+ """
+ super().__init__()
+ self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
+ for p in self.parameters():
+ p.requires_grad = False
+
+ self.backbone = backbone
+ init_tau=np.log(10)
+ init_b=-2.71
+ # self.t_prime = nn.Parameter(torch.ones([]) * init_tau)
+ # self.b = nn.Parameter(torch.ones([]) * init_b)
+ self.kb = True
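+ # Two stride-2 transposed convolutions upsample the stride-32 ResNet feature map by 4x
+ # (to stride 8); ocr_mlp then projects 512 -> 2048 to match the language embedding width.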
+ self.upsample = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels=2048,
+ out_channels=512,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias=False
+ ),
+ nn.SyncBatchNorm(512),
+ nn.ConvTranspose2d(
+ in_channels=512,
+ out_channels=512,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ bias=False
+ ),
+ nn.SyncBatchNorm(512),
+ )
+ self.ocr_mlp = nn.Sequential(
+ nn.Linear(512, 2048),
+ nn.GELU(),
+ nn.Linear(2048, 2048)
+ )
+
+ def forward(self,
+ pixel_values: torch.FloatTensor,
+ input_ids: torch.LongTensor = None,
+ image_flags: Optional[torch.LongTensor] = None,
+ mask_values: Optional[torch.LongTensor] = None,
+ masks_flags: Optional[torch.LongTensor] = None,
+ mask_nums: Optional[torch.LongTensor] = None,
+ ):
+ image_flags = image_flags.squeeze(-1)
+ try:
+ input_embeds = self.language_embedding(input_ids).clone()
+ except:
+ print('error'*1000)
+ import pdb; pdb.set_trace()
+ # import pdb; pdb.set_trace()
+ vit_embeds, vit_embeds_shape = self.extract_feature_custom(pixel_values) #(vit_batch_size, 16*16, 2048)
+ nb, nl, nd = vit_embeds.shape
+ h, w = vit_embeds_shape
+ vit_embeds = vit_embeds.reshape(nb, h, w, nd)
+ vit_embeds = vit_embeds.split(list(image_flags)) #[(vit_batch_size / B, h, w, C)]*B
+ vit_batch_size = pixel_values.shape[0]
+
+ B, N, C = input_embeds.shape
+ try:
+ assert sum(image_flags) == mask_values.shape[0]
+ except:
+ print((mask_values.shape, image_flags, mask_nums))
+
+ mask_values = torch.nn.functional.interpolate(mask_values.float(), size=(h, w), mode='bilinear', align_corners=False) #(128, 128)
+ masks = mask_values.split(list(image_flags)) #[(vit_batch_size / B, N, 448, 448)]*B
+
+
+ masks_flags = masks_flags.chunk(B)
+ token_features = []
+ input_embedings = []
+ masked_input_ids = []
+ masked_zero_bools = []
+ for i, vit_embed in enumerate(vit_embeds):
+ current_token = masks_flags[i].sum()
+ mask = masks[i]
+ limit_num = mask.shape[1]
+ mask = mask.permute(1,0,2,3).reshape(limit_num, -1) > 0
+ max_cluster_index = mask.sum(-1)
+ zero_bool = max_cluster_index != 0
+ # import pdb; pdb.set_trace()
+ mask[~zero_bool] = 1 # for addressing a bfloat16 bug
+ new_max_cluster_index = mask.sum(-1)
+ mask = mask / new_max_cluster_index.unsqueeze(-1)
+ token_feature = torch.matmul(mask.to(vit_embed), vit_embed.reshape(-1, vit_embed.shape[-1]))
+ token_features.extend(token_feature)
+ input_embedings.extend(input_embeds[i, :])
+ masked_input_ids.extend(input_ids[i, zero_bool])
+ masked_zero_bools.append(zero_bool)
+
+ masked_zero_bools = torch.cat(masked_zero_bools)
+ token_features = torch.stack(token_features)
+ input_embedings= torch.stack(input_embedings)
+
+ loss2 = F.mse_loss(token_features, input_embedings, reduction='none')[masked_zero_bools].sum(1).sqrt().mean()
+ token_features = token_features / token_features.norm(dim=1, keepdim=True)
+ input_embedings = input_embedings / input_embedings.norm(dim=1, keepdim=True)
+ # cosine similarity as logits
+ similarity = F.cosine_similarity(token_features, input_embedings, dim=1)
+ loss1 = (1 - similarity[masked_zero_bools]).mean()
+ # loss_d = loss1 + loss2
+ # if rank == 0:
+ # print(f'loss1:{loss_d}')
+
+ ###siglip
+ # masked_input_ids = torch.stack(masked_input_ids)
+ # label_matrix = (masked_input_ids.unsqueeze(0) == masked_input_ids.unsqueeze(1)).int()
+ # label_matrix = 2 * label_matrix - 1
+ # if self.kb:
+ # logits = (input_embedings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embedings.device).exp() + self.b.to(input_embedings.device)
+ # else:
+ # logits = (input_embedings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embedings.device).exp() - 8.9375
+ # loss_s = -torch.sum(F.logsigmoid(label_matrix * logits)) / logits.shape[0]
+ # if rank == 0:
+ # print(f'loss2:{loss_s}')
+ return loss1, loss2
+
+ def forward_tokenocr(self, pixel_values):
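+ # Shape sketch (assuming a 448x448 input): layer4 features are (B, 2048, 14, 14), upsampled
+ # to (B, 512, 56, 56) and projected to (B, 56*56, 2048); the returned size is then (56, 56).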
+ vit_embeds = self.backbone(pixel_values)
+ vit_embeds = vit_embeds['0']
+ h, w = vit_embeds.shape[2], vit_embeds.shape[3]
+ vit_embeds = self.upsample(vit_embeds)
+ vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1])
+ vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1))
+ return vit_embeds, (h*4, w*4)
+
+
+class MLP(nn.Module):
+ """ Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+def build(args):
+ backbone = build_backbone(args)
+ model = TokenOCR(backbone)
+ return model
diff --git a/tokenizer_path/special_tokens_map.json b/tokenizer_path/special_tokens_map.json
new file mode 100644
index 0000000000000000000000000000000000000000..14c079b57d740ab21ceced78982ee102ff4fea48
--- /dev/null
+++ b/tokenizer_path/special_tokens_map.json
@@ -0,0 +1,48 @@
+{
+ "additional_special_tokens": [
+ "<|im_start|>",
+ "<|im_end|>",
+ "<|action_start|>",
+ "<|action_end|>",
+ "<|interpreter|>",
+ "<|plugin|>",
+ "
",
+ "",
+ "",
+ "",
+ "",
+ "[",
+ "]",
+ "",
+ ""
+ ],
+ "bos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "unk_token": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
+
\ No newline at end of file
diff --git a/tokenizer_path/tokenization_internlm2.py b/tokenizer_path/tokenization_internlm2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1be581da37ef678de65f2737493fc0ed7160446e
--- /dev/null
+++ b/tokenizer_path/tokenization_internlm2.py
@@ -0,0 +1,235 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization classes for InternLM."""
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
+
+PRETRAINED_VOCAB_FILES_MAP = {}
+
+
+# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
+class InternLM2Tokenizer(PreTrainedTokenizer):
+ """
+ Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ """
+
+ vocab_files_names = VOCAB_FILES_NAMES
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+ model_input_names = ['input_ids', 'attention_mask']
+ _auto_class = 'AutoTokenizer'
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token='<unk>',
+ bos_token='<s>',
+ eos_token='</s>',
+ pad_token='</s>',
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ decode_with_prefix_space=False,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+ self.vocab_file = vocab_file
+ self.add_bos_token = add_bos_token
+ self.add_eos_token = add_eos_token
+ self.decode_with_prefix_space = decode_with_prefix_space
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(vocab_file)
+ self._no_prefix_space_tokens = None
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ pad_token=pad_token,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+
+ @property
+ def no_prefix_space_tokens(self):
+ if self._no_prefix_space_tokens is None:
+ vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
+ self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')}
+ return self._no_prefix_space_tokens
+
+ @property
+ def vocab_size(self):
+ """Returns vocab size"""
+ return self.sp_model.get_piece_size()
+
+ @property
+ def bos_token_id(self) -> Optional[int]:
+ return self.sp_model.bos_id()
+
+ @property
+ def eos_token_id(self) -> Optional[int]:
+ return self.sp_model.eos_id()
+
+ def get_vocab(self):
+ """Returns vocab as a dict"""
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text):
+ """Returns a tokenized string."""
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ return self.sp_model.piece_to_id(token)
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ token = self.sp_model.IdToPiece(index)
+ return token
+
+ def _maybe_add_prefix_space(self, tokens, decoded):
+ if tokens and tokens[0] not in self.no_prefix_space_tokens:
+ return ' ' + decoded
+ else:
+ return decoded
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (string) in a single string."""
+ current_sub_tokens = []
+ out_string = ''
+ prev_is_special = False
+ for token in tokens:
+ # make sure that special tokens are not decoded using sentencepiece model
+ if token in self.all_special_tokens:
+ if not prev_is_special:
+ out_string += ' '
+ out_string += self.sp_model.decode(current_sub_tokens) + token
+ prev_is_special = True
+ current_sub_tokens = []
+ else:
+ current_sub_tokens.append(token)
+ prev_is_special = False
+ out_string += self.sp_model.decode(current_sub_tokens)
+ out_string = self.clean_up_tokenization(out_string)
+ out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
+ return out_string[1:]
+
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ """
+ Save the vocabulary and special tokens file to a directory.
+
+ Args:
+ save_directory (`str`):
+ The directory in which to save the vocabulary.
+
+ Returns:
+ `Tuple(str)`: Paths to the files saved.
+ """
+ if not os.path.isdir(save_directory):
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, 'wb') as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
+
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
+ if self.add_bos_token:
+ bos_token_ids = [self.bos_token_id]
+ else:
+ bos_token_ids = []
+
+ output = bos_token_ids + token_ids_0
+
+ if token_ids_1 is not None:
+ output = output + token_ids_1
+
+ if self.add_eos_token:
+ output = output + [self.eos_token_id]
+
+ return output
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
+ use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+ """
+ eos = [self.eos_token_id]
+
+ if token_ids_1 is None:
+ return len(token_ids_0 + eos) * [0]
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
diff --git a/tokenizer_path/tokenization_internlm2_fast.py b/tokenizer_path/tokenization_internlm2_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa0fccbd0f1d029d79e19821f2edcb01b594537c
--- /dev/null
+++ b/tokenizer_path/tokenization_internlm2_fast.py
@@ -0,0 +1,211 @@
+# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tokenization Fast class for InternLM."""
+import os
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple
+
+from tokenizers import Tokenizer, decoders, normalizers, processors
+from tokenizers.models import BPE
+from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
+ SentencePieceExtractor,
+ SpmConverter)
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
+from transformers.utils import logging
+
+from .tokenization_internlm2 import InternLM2Tokenizer
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
+
+
+# Modified from transformers.convert_slow_tokenizer.LlamaConverter
+class InternLM2Converter(SpmConverter):
+ handle_byte_fallback = True
+
+ def vocab(self, proto):
+ vocab = [
+ ('<unk>', 0.0),
+ ('<s>', 0.0),
+ ('</s>', 0.0),
+ ]
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
+ return vocab
+
+ def unk_id(self, proto):
+ unk_id = 0
+ return unk_id
+
+ def decoder(self, replacement, add_prefix_space):
+ return decoders.Sequence(
+ [
+ decoders.Replace('▁', ' '),
+ decoders.ByteFallback(),
+ decoders.Fuse(),
+ decoders.Strip(content=' ', left=1),
+ ]
+ )
+
+ def tokenizer(self, proto):
+ model_type = proto.trainer_spec.model_type
+ vocab_scores = self.vocab(proto)
+ # special tokens
+ added_tokens = self.original_tokenizer.added_tokens_decoder
+ for i in range(len(vocab_scores)):
+ piece, score = vocab_scores[i]
+ if i in added_tokens:
+ vocab_scores[i] = (added_tokens[i].content, score)
+ if model_type == 1:
+ raise RuntimeError('InternLM2 is supposed to be a BPE model!')
+
+ elif model_type == 2:
+ _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
+ bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
+ tokenizer = Tokenizer(
+ BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
+ )
+ tokenizer.add_special_tokens(
+ [ added_token for index, added_token in added_tokens.items()]
+ )
+ else:
+ raise Exception(
+ "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
+ )
+
+ return tokenizer
+
+ def normalizer(self, proto):
+ normalizers_list = []
+ if proto.normalizer_spec.add_dummy_prefix:
+ normalizers_list.append(normalizers.Prepend(prepend='▁'))
+ normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
+ return normalizers.Sequence(normalizers_list)
+
+ def pre_tokenizer(self, replacement, add_prefix_space):
+ return None
+
+
+SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
+
+
+# Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
+class InternLM2TokenizerFast(PreTrainedTokenizerFast):
+ vocab_files_names = VOCAB_FILES_NAMES
+ slow_tokenizer_class = InternLM2Tokenizer
+ padding_side = 'left'
+ model_input_names = ['input_ids', 'attention_mask']
+ _auto_class = 'AutoTokenizer'
+
+ def __init__(
+ self,
+ vocab_file,
+ unk_token='<unk>',
+ bos_token='<s>',
+ eos_token='</s>',
+ pad_token='</s>',
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ add_bos_token=True,
+ add_eos_token=False,
+ decode_with_prefix_space=False,
+ clean_up_tokenization_spaces=False,
+ **kwargs,
+ ):
+ super().__init__(
+ vocab_file=vocab_file,
+ unk_token=unk_token,
+ bos_token=bos_token,
+ eos_token=eos_token,
+ pad_token=pad_token,
+ sp_model_kwargs=sp_model_kwargs,
+ add_bos_token=add_bos_token,
+ add_eos_token=add_eos_token,
+ decode_with_prefix_space=decode_with_prefix_space,
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+ **kwargs,
+ )
+ self._add_bos_token = add_bos_token
+ self._add_eos_token = add_eos_token
+ self.update_post_processor()
+ self.vocab_file = vocab_file
+
+ @property
+ def can_save_slow_tokenizer(self) -> bool:
+ return os.path.isfile(self.vocab_file) if self.vocab_file else False
+
+ def update_post_processor(self):
+ """
+ Updates the underlying post processor with the current `bos_token` and `eos_token`.
+ """
+ bos = self.bos_token
+ bos_token_id = self.bos_token_id
+ if bos is None and self.add_bos_token:
+ raise ValueError('add_bos_token = True but bos_token = None')
+
+ eos = self.eos_token
+ eos_token_id = self.eos_token_id
+ if eos is None and self.add_eos_token:
+ raise ValueError('add_eos_token = True but eos_token = None')
+
+ single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
+ pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
+
+ special_tokens = []
+ if self.add_bos_token:
+ special_tokens.append((bos, bos_token_id))
+ if self.add_eos_token:
+ special_tokens.append((eos, eos_token_id))
+ self._tokenizer.post_processor = processors.TemplateProcessing(
+ single=single, pair=pair, special_tokens=special_tokens
+ )
+
+ @property
+ def add_eos_token(self):
+ return self._add_eos_token
+
+ @property
+ def add_bos_token(self):
+ return self._add_bos_token
+
+ @add_eos_token.setter
+ def add_eos_token(self, value):
+ self._add_eos_token = value
+ self.update_post_processor()
+
+ @add_bos_token.setter
+ def add_bos_token(self, value):
+ self._add_bos_token = value
+ self.update_post_processor()
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not self.can_save_slow_tokenizer:
+ raise ValueError(
+ 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
+ 'tokenizer.'
+ )
+
+ if not os.path.isdir(save_directory):
+ logger.error(f'Vocabulary path ({save_directory}) should be a directory')
+ return
+ out_vocab_file = os.path.join(
+ save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+
+ return (out_vocab_file,)
diff --git a/tokenizer_path/tokenizer.model b/tokenizer_path/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..6600712949ca9c4ffb50f25275993a21fba0b408
--- /dev/null
+++ b/tokenizer_path/tokenizer.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b
+size 1477754
diff --git a/tokenizer_path/tokenizer_config.json b/tokenizer_path/tokenizer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..d102412748394fbc89824d82a399710d3468e01f
--- /dev/null
+++ b/tokenizer_path/tokenizer_config.json
@@ -0,0 +1,180 @@
+{
+ "added_tokens_decoder": {
+ "0": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92538": {
+ "content": "<|plugin|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92539": {
+ "content": "<|interpreter|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92540": {
+ "content": "<|action_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92541": {
+ "content": "<|action_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92542": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92543": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92544": {
+ "content": "
",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92545": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92546": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92547": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92548": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92549": {
+ "content": "[",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92550": {
+ "content": "]",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92551": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "92552": {
+ "content": "",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [
+ "<|im_start|>",
+ "<|im_end|>",
+ "<|action_start|>",
+ "<|action_end|>",
+ "<|interpreter|>",
+ "<|plugin|>",
+ "
",
+ "",
+ "",
+ "",
+ "",
+ "[",
+ "]",
+ "",
+ ""
+ ],
+ "auto_map": {
+ "AutoTokenizer": [
+ "tokenization_internlm2.InternLM2Tokenizer",
+ null
+ ]
+ },
+ "bos_token": "",
+ "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "",
+ "model_max_length": 16384,
+ "pad_token": "",
+ "tokenizer_class": "InternLM2Tokenizer",
+ "unk_token": ""
+ }
+
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..074c42d61df8da4ffc6febcf68fd3fa5940b6f34
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,194 @@
+import os
+import torch
+import torch.nn as nn
+from internvl.train.dataset import build_transform, dynamic_preprocess
+from internvl.model.internvl_chat import InternVisionModel, InternVLChatModel
+from torchvision.utils import make_grid
+import torchvision.transforms as T
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor
+import cv2
+from PIL import Image
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
+CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
+SIGLIP_MEAN = (0.5, 0.5, 0.5)
+SIGLIP_STD = (0.5, 0.5, 0.5)
+IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+IMG_START_TOKEN = '<img>'
+IMG_END_TOKEN = '</img>'
+QUAD_START_TOKEN = '<quad>'
+QUAD_END_TOKEN = '</quad>'
+REF_START_TOKEN = '['
+REF_END_TOKEN = ']'
+BOX_START_TOKEN = '<box>'
+BOX_END_TOKEN = '</box>'
+
+def load_model(config, state_dict):
+ vision_model = InternVisionModel(config.vision_config)
+ vit = InternVLChatModel(config, vision_model).to(torch.bfloat16)
+ vit.load_state_dict(state_dict, strict=False)
+ tok_embeddings = nn.Embedding(config.llm_config.vocab_size, config.llm_config.hidden_size, 2).to(torch.bfloat16)
+ tok_embeddings.weight = nn.Parameter(state_dict['language_model.model.tok_embeddings.weight'])
+ return vit, tok_embeddings
+
+def load_image(image_path):
+ transform = get_transform(is_train=False, image_size=448)
+ image = Image.open(image_path).convert('RGB')
+ images, target_aspect_ratio = dynamic_preprocess(image, min_num=1, max_num=12,
+ image_size=448, use_thumbnail=True, return_ratio=True)
+ pixel_values = [transform(image) for image in images]
+ pixel_values = torch.stack(pixel_values).to(torch.bfloat16)
+ return pixel_values, images, target_aspect_ratio
+
+def get_similarity_map(sm, shape, min_max=True, threshold=0.2):
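+ # sm: (B, N, H, W) similarity maps; each map is min-max normalised (or thresholded) over the
+ # spatial dimensions and bilinearly resized to `shape` = (target_height, target_width).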
+ B, N, H, W = sm.shape
+ sm = sm.reshape(B, N, H*W)
+ if min_max:
+ # min-max norm
+ sm = (sm - sm.min(2, keepdim=True)[0]) / (sm.max(2, keepdim=True)[0] - sm.min(2, keepdim=True)[0])
+ else:
+ sm = sm > threshold
+ sm = sm.float()
+ # reshape
+ sm = sm.reshape(B, N, H, W).float()
+ # interpolate
+ sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
+ return sm
+
+def build_transform_R50(normalize_type='imagenet'):
+ if normalize_type == 'imagenet':
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+ elif normalize_type == 'clip':
+ MEAN, STD = CLIP_MEAN, CLIP_STD
+ elif normalize_type == 'siglip':
+ MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
+ else:
+ raise NotImplementedError
+ transform = T.Compose([
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+ T.ToTensor(),
+ T.Normalize(mean=MEAN, std=STD)
+ ])
+ return transform
+
+def load_tokenizer(tokenizer_path):
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False)
+ tokenizer.tokenizer_path = tokenizer_path
+ tokenizer.model_max_length = 8192
+ token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
+ QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
+ REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
+ num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
+ return tokenizer
+
+
+def get_transform(is_train, image_size):
+ # Build transformation function
+ transform = build_transform(is_train=is_train, input_size=image_size,
+ pad2square=False, normalize_type='imagenet')
+ return transform
+
+def post_process(vit_embeds, target_aspect_ratio, model_type='VIT'):
+ if model_type in ["TokenOCR-4096-English-seg", "TokenOCR-2048-Bilingual-seg"]:
+ h = w = int(vit_embeds.shape[1] ** 0.5)
+ c = vit_embeds.shape[-1]
+ # vit_embeds_local = vit_embeds[:-1].reshape(-1, h, w, c).permute(0, 3, 1, 2)
+ if vit_embeds.shape[0] == 1:
+ vit_embeds_local = vit_embeds.reshape(-1, h, w, c).permute(0, 3, 1, 2)
+ else:
+ vit_embeds_local = vit_embeds[:-1].reshape(-1, h, w, c).permute(0, 3, 1, 2)
+ vit_embeds_local = make_grid(vit_embeds_local, nrow=target_aspect_ratio[0], padding=0, normalize=False)
+ vit_embeds_local = vit_embeds_local.permute(1,2,0)
+ H, W, C = vit_embeds_local.shape
+ vit_embeds_local = vit_embeds_local.reshape(H*W, C)
+ return vit_embeds_local, (H, W)
+ if 'R50' in model_type:
+ vit_embeds = vit_embeds.reshape(-1, vit_embeds.shape[-1])
+ return vit_embeds, None
+
+def generate_similiarity_map(images, attn_map, all_bpe_strings, vis_list, target_aspect_ratio=(1,1), src_iamge_size=(1014, 1024), image_size=448):
+ # if isinstance(images, list):
+ # print("111111111")
+ # if len(images) == 1:
+ # images_vis = torch.stack([T.ToTensor()(image) for image in images])
+ # else:
+ # images_vis = torch.stack([T.ToTensor()(image) for image in images[:-1]])
+
+ # images_vis = make_grid(images_vis, nrow=target_aspect_ratio[0], padding=0, normalize=False)
+ # print("image_size",image_size)
+ # print("target_aspect_ratio[0]",target_aspect_ratio[0])
+ # print("target_aspect_ratio[1]",target_aspect_ratio[1])
+ # target_width = image_size * target_aspect_ratio[0]
+ # target_height = image_size * target_aspect_ratio[1]
+
+ # else:
+ # print("222222222")
+
+ # images_vis = T.ToTensor()(images)
+ # target_height = images.size[1]
+ # target_width = images.size[0]
+ # print("images_vis",images_vis)
+ # print("target_height",images.size[1])
+ # print("target_width",images.size[0])
+
+ images = images[0]
+ images_vis = T.ToTensor()(images) # images []
+ print("images",images)
+ print("images_vis",images_vis)
+ target_height = images.size[1]
+ target_width = images.size[0]
+
+
+ print("attn_map",attn_map.shape)# torch.Size([4, 76, 128])
+ print("target_height",target_height) #有问题 608
+ print("target_width",target_width) #有问题 1024
+
+ attn_norm = get_similarity_map(attn_map.unsqueeze(0), (target_height, target_width), min_max=True, threshold=0.15)
+ print("attn_norm ",attn_norm.shape) # 有问题attn_norm torch.Size([1, 4, 448, 448])
+ print('all_bpe_strings:{:}'.format(all_bpe_strings))
+ indexes_without_space = torch.tensor([index for index, string in enumerate(all_bpe_strings) if string != ' '])
+
+ # Draw similarity map
+ # print(images_vis.shape)
+ images_vis = (images_vis.permute(1,2,0).cpu().numpy() * 125).astype('uint8')
+ for b in range(attn_norm.shape[0]):
+ for n in range(attn_norm.shape[1]-1):
+ vis = (attn_norm[b, n, :, :].float().detach().cpu().numpy() * 255).astype('uint8')
+ vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET)
+ print("images_vis",images_vis.shape)
+ print("vis",vis.shape)
+ vis = images_vis * 0.5 + vis * 0.5
+ vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB)
+ vis = cv2.resize(vis, src_iamge_size)
+ vis_list.append(vis) # Add each visualization to the list
+
+ without_space_norm = attn_norm[b, indexes_without_space, :, :].max(0)[0]
+ space_norm = attn_norm[b, -1, :, :]
+ all_attn_norm = without_space_norm - space_norm
+ print(f'min:{all_attn_norm.min()};max:{all_attn_norm.max()}')
+ all_attn_norm = (all_attn_norm - all_attn_norm.min()) / (all_attn_norm.max() - all_attn_norm.min())
+ all_attn_norm = (all_attn_norm.float().detach().cpu().numpy() * 255).astype('uint8')
+ vis = cv2.applyColorMap(all_attn_norm, cv2.COLORMAP_JET)
+ vis = images_vis * 0.5 + vis * 0.5
+ vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB)
+ vis = cv2.resize(vis, src_iamge_size)
+ vis_list.append(vis) # Add each visualization to the list
+
+ return vis_list
+
+
+def load_model_and_tokenizer_customed(checkpoint):
+ kwargs = {}
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
+ model = InternVLChatModel.from_pretrained(
+ checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
+ load_in_8bit=False, load_in_4bit=False, **kwargs).eval()
+ del model.language_model.model.layers
+ del model.language_model.output
+ model = model.cuda()
+ return model, tokenizer
\ No newline at end of file