diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..41d2271826c4932cc927e5e299c8f4146eda704a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/examples0.jpg filter=lfs diff=lfs merge=lfs -text +examples/examples1.jpg filter=lfs diff=lfs merge=lfs -text +examples/examples2.png filter=lfs diff=lfs merge=lfs -text diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c7cc3557d13dcb8bbcba160be2af18942e12d163 --- /dev/null +++ b/app.py @@ -0,0 +1,195 @@ +import os +import argparse +import numpy as np +from PIL import Image +import torch +import torchvision.transforms as T +from transformers import AutoTokenizer +import gradio as gr +from resnet50 import build_model +from utils import generate_similiarity_map, post_process, load_tokenizer, build_transform_R50 +from utils import IMAGENET_MEAN, IMAGENET_STD +from internvl.train.dataset import dynamic_preprocess +from internvl.model.internvl_chat import InternVLChatModel + +# 模型配置 +CHECKPOINTS = { + "TokenOCR-4096-English-seg": "/path/to/TokenOCR_4096_English_seg", + "TokenOCR-2048-Bilingual-seg": "/path/to/TokenOCR_2048_Binlinual_seg", + "R50":"model/checkpoint.pth", + "R50_siglip": "/path/to/R50_siglip_checkpoint.pth" +} + +# 全局变量 +current_vis = [] +current_bpe = [] +current_index = 0 + +def load_model(check_type): + device = torch.device("cpu") + + if check_type == 'R50': + tokenizer = load_tokenizer('tokenizer_path') + model = build_model(argparse.Namespace()).eval() + model.load_state_dict(torch.load(CHECKPOINTS['R50'], map_location='cpu')['model']) + transform = build_transform_R50(normalize_type='imagenet') + + elif check_type == 'R50_siglip': + tokenizer = load_tokenizer('tokenizer_path') + model = build_model(argparse.Namespace()).eval() + model.load_state_dict(torch.load(CHECKPOINTS['R50_siglip'], map_location='cpu')['model']) + transform = build_transform_R50(normalize_type='imagenet') + + elif 'TokenOCR' in check_type: + model_path = CHECKPOINTS[check_type] + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False) + model = InternVLChatModel.from_pretrained(model_path, torch_dtype=torch.bfloat16).eval() + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.Resize((224, 224)), + T.ToTensor(), + T.Normalize(IMAGENET_MEAN, IMAGENET_STD) + ]) + + return model.to(device), tokenizer, transform, device + +def process_image(model, tokenizer, transform, device, check_type, image, text): + global current_vis, current_bpe + src_size = image.size + if 'TokenOCR' in check_type: + images, target_ratio = dynamic_preprocess(image, min_num=1, max_num=12, + image_size=model.config.force_image_size, + use_thumbnail=model.config.use_thumbnail, + return_ratio=True) + pixel_values = torch.stack([transform(img) for img in images]).to(device) + else: + pixel_values = torch.stack([transform(image)]).to(device) + target_ratio = (1, 1) + + # 文本处理 + text += ' ' + input_ids = tokenizer(text)['input_ids'][1:] + input_ids = torch.tensor(input_ids, device=device) + + # 获取嵌入 + with torch.no_grad(): + if 'R50' in check_type: + text_embeds = model.language_embedding(input_ids) + else: + text_embeds = model.tok_embeddings(input_ids) + + vit_embeds, size1 = model.forward_tokenocr(pixel_values.to(device)) + vit_embeds, size2 = 
post_process(vit_embeds, target_ratio, check_type) + + # 计算相似度 + text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + vit_embeds = vit_embeds / vit_embeds.norm(dim=-1, keepdim=True) + similarity = text_embeds @ vit_embeds.T + resized_size = size1 if size1 is not None else size2 + + # print(f"text_embeds shape: {text_embeds.shape}, numel: {text_embeds.numel()}") # text_embeds shape: torch.Size([4, 2048]), numel: 8192 + # print(f"vit_embeds shape: {vit_embeds.shape}, numel: {vit_embeds.numel()}") # vit_embeds shape: torch.Size([9728, 2048]), numel: 19922944 + # print(f"similarity shape: {similarity.shape}, numel: {similarity.numel()}")# similarity shape: torch.Size([4, 9728]), numel: 38912 + + + # 生成可视化 + attn_map = similarity.reshape(len(text_embeds), resized_size[0], resized_size[1]) + # attn_map = similarity.reshape(len(text_embeds), *target_ratio) + all_bpe_strings = [tokenizer.decode(input_id) for input_id in input_ids] + current_vis = generate_similiarity_map([image], attn_map, + [tokenizer.decode([i]) for i in input_ids], + [], target_ratio, src_size) + + current_bpe = [tokenizer.decode([i]) for i in input_ids] + # current_bpe[-1] = 'Input text' + current_bpe[-1] = text + return image, current_vis[0], current_bpe[0] + +# 事件处理函数 +def update_index(change): + global current_index + current_index = max(0, min(len(current_vis) - 1, current_index + change)) + return current_vis[current_index], format_bpe_display(current_bpe[current_index]) + +def format_bpe_display(bpe): + # 使用HTML标签来设置字体大小、颜色,加粗,并居中 + return f"
Current BPE: {bpe}
" + +# Gradio界面 +with gr.Blocks(title="BPE Visualization Demo") as demo: + gr.Markdown("## BPE Visualization Demo - TokenOCR基座模型能力可视化") + + with gr.Row(): + with gr.Column(scale=0.5): + model_type = gr.Dropdown( + choices=["TokenOCR-4096-English-seg", "TokenOCR-2048-Bilingual-seg", "R50", "R50_siglip"], + label="Select model type", + value="R50" # 设置默认值为第一个选项 + ) + image_input = gr.Image(label="Upload images", type="pil") + text_input = gr.Textbox(label="Input text") + + run_btn = gr.Button("RUN") + + gr.Examples( + examples=[ + [os.path.join("examples", "examples0.jpg"), "Veterans and Benefits"], + [os.path.join("examples", "examples1.jpg"), "Refreshers"], + [os.path.join("examples", "examples2.png"), "Vision Transformer"] + ], + inputs=[image_input, text_input], + label="Sample input" + ) + + with gr.Column(scale=2): + gr.Markdown("

If the input text does not appear in the image, the attention map will look very noisy (its underlying response values are actually very low), because the map is normalized by its relative values.

") + + with gr.Row(): + orig_img = gr.Image(label="Original picture", interactive=False) + heatmap = gr.Image(label="BPE visualization", interactive=False) + + with gr.Row() as controls: + prev_btn = gr.Button("⬅ Last", visible=False) + index_slider = gr.Slider(0, 1, value=0, step=1, label="BPE index", visible=False) + next_btn = gr.Button("⮕ Next", visible=False) + + bpe_display = gr.Markdown("Current BPE: ", visible=False) + + # 事件处理 + def on_run_clicked(model_type, image, text): + global current_vis, current_bpe, current_index + current_index = 0 # Reset index when new image is processed + image, vis, bpe = process_image(*load_model(model_type), model_type, image, text) + # Update the slider range and set value to 0 + slider_max_val = len(current_bpe) - 1 + bpe_text = format_bpe_display(bpe) + return image, vis, bpe_text, slider_max_val + + run_btn.click( + on_run_clicked, + inputs=[model_type, image_input, text_input], + outputs=[orig_img, heatmap, bpe_display, index_slider], + ).then( + lambda max_val: (gr.update(visible=True), gr.update(visible=True, maximum=max_val, value=0), gr.update(visible=True), gr.update(visible=True)), + inputs=index_slider, + outputs=[prev_btn, index_slider, next_btn, bpe_display], + ) + + prev_btn.click( + lambda: (*update_index(-1), current_index), + outputs=[heatmap, bpe_display, index_slider] + ) + + next_btn.click( + lambda: (*update_index(1), current_index), + outputs=[heatmap, bpe_display, index_slider] + ) + + index_slider.change( + lambda x: (current_vis[x], format_bpe_display(current_bpe[x])), + inputs=index_slider, + outputs=[heatmap, bpe_display] + ) + +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/examples/examples0.jpg b/examples/examples0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..89ce15bfdc800ec5851f17369f7e1bbfc8e2f8a6 --- /dev/null +++ b/examples/examples0.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:afdea4f2f6523d2a8f9a10ed010ea57e257908fa284485b224c276310f723121 +size 423018 diff --git a/examples/examples1.jpg b/examples/examples1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fdd23474194746ea432b850ecd53e6f291f0818e --- /dev/null +++ b/examples/examples1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:347fe59f49c31c93f78a913c891eaeeeace59851f69dfb2c2631dfc7dd6c3886 +size 571030 diff --git a/examples/examples2.png b/examples/examples2.png new file mode 100644 index 0000000000000000000000000000000000000000..7fb40590b412ee234aa0298c27e7c5bfa8a13fd1 --- /dev/null +++ b/examples/examples2.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8066fab0957545e840369338b2693136f60cecf61a6eb85a9c650ad1e3f69e27 +size 1470171 diff --git a/internvl/__pycache__/conversation.cpython-310.pyc b/internvl/__pycache__/conversation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfbde9fd800528506c3b2af19379db5842b6a883 Binary files /dev/null and b/internvl/__pycache__/conversation.cpython-310.pyc differ diff --git a/internvl/__pycache__/conversation.cpython-39.pyc b/internvl/__pycache__/conversation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de0bcb38f21f0c673cbe4b936c46203f6e941887 Binary files /dev/null and b/internvl/__pycache__/conversation.cpython-39.pyc differ diff --git a/internvl/__pycache__/dist_utils.cpython-39.pyc b/internvl/__pycache__/dist_utils.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8aef188afd67ddb6ecd1cf8717269190b08a8d1f Binary files /dev/null and b/internvl/__pycache__/dist_utils.cpython-39.pyc differ diff --git a/internvl/conversation.py b/internvl/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..cf9e80bebbd5342f313fbef61c2b21194157a0db --- /dev/null +++ b/internvl/conversation.py @@ -0,0 +1,402 @@ +""" +Conversation prompt templates. + +We kindly request that you import fastchat instead of copying this file if you wish to use it. +If you have changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates. +""" + +import dataclasses +from enum import IntEnum, auto +from typing import Any, Dict, List, Tuple, Union + + +class SeparatorStyle(IntEnum): + """Separator styles.""" + + ADD_COLON_SINGLE = auto() + ADD_COLON_TWO = auto() + ADD_COLON_SPACE_SINGLE = auto() + NO_COLON_SINGLE = auto() + NO_COLON_TWO = auto() + ADD_NEW_LINE_SINGLE = auto() + LLAMA2 = auto() + CHATGLM = auto() + CHATML = auto() + CHATINTERN = auto() + DOLLY = auto() + RWKV = auto() + PHOENIX = auto() + ROBIN = auto() + FALCON_CHAT = auto() + CHATGLM3 = auto() + INTERNVL_ZH = auto() + MPT = auto() + + +@dataclasses.dataclass +class Conversation: + """A class that manages prompt templates and keeps all conversation history.""" + + # The name of this template + name: str + # The template of the system prompt + system_template: str = '{system_message}' + # The system message + system_message: str = '' + # The names of two roles + roles: Tuple[str] = ('USER', 'ASSISTANT') + # All messages. Each item is (role, message). + messages: List[List[str]] = () + # The number of few shot examples + offset: int = 0 + # The separator style and configurations + sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE + sep: str = '\n' + sep2: str = None + # Stop criteria (the default one is EOS token) + stop_str: Union[str, List[str]] = None + # Stops generation if meeting any token in this list + stop_token_ids: List[int] = None + + def get_prompt(self) -> str: + """Get the prompt for generation.""" + system_prompt = self.system_template.format(system_message=self.system_message) + if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: + seps = [self.sep, self.sep2] + ret = system_prompt + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ': ' # must be end with a space + return ret + elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: + ret = '' if system_prompt == '' else system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + else: + ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + message + self.sep + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.NO_COLON_TWO: + seps = [self.sep, self.sep2] + ret = 
system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + message + seps[i % 2] + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.RWKV: + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += ( + role + + ': ' + + message.replace('\r\n', '\n').replace('\n\n', '\n') + ) + ret += '\n\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.LLAMA2: + seps = [self.sep, self.sep2] + if self.system_message: + ret = system_prompt + else: + ret = '[INST] ' + for i, (role, message) in enumerate(self.messages): + tag = self.roles[i % 2] + if message: + if i == 0: + ret += message + ' ' + else: + ret += tag + ' ' + message + seps[i % 2] + else: + ret += tag + return ret + elif self.sep_style == SeparatorStyle.CHATGLM: + # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 + # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 + round_add_n = 1 if self.name == 'chatglm2' else 0 + if system_prompt: + ret = system_prompt + self.sep + else: + ret = '' + + for i, (role, message) in enumerate(self.messages): + if i % 2 == 0: + ret += f'[Round {i//2 + round_add_n}]{self.sep}' + + if message: + ret += f'{role}:{message}{self.sep}' + else: + ret += f'{role}:' + return ret + elif self.sep_style == SeparatorStyle.CHATML: + ret = '' if system_prompt == '' else system_prompt + self.sep + '\n' + for role, message in self.messages: + if message: + ret += role + '\n' + message + self.sep + '\n' + else: + ret += role + '\n' + return ret + elif self.sep_style == SeparatorStyle.CHATGLM3: + ret = '' + if self.system_message: + ret += system_prompt + for role, message in self.messages: + if message: + ret += role + '\n' + ' ' + message + else: + ret += role + return ret + elif self.sep_style == SeparatorStyle.CHATINTERN: + # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + # if i % 2 == 0: + # ret += "" + if message: + ret += role + ':' + message + seps[i % 2] + '\n' + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.DOLLY: + seps = [self.sep, self.sep2] + ret = system_prompt + for i, (role, message) in enumerate(self.messages): + if message: + ret += role + ':\n' + message + seps[i % 2] + if i % 2 == 1: + ret += '\n\n' + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.PHOENIX: + ret = system_prompt + for role, message in self.messages: + if message: + ret += role + ': ' + '' + message + '' + else: + ret += role + ': ' + '' + return ret + elif self.sep_style == SeparatorStyle.ROBIN: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ':\n' + message + self.sep + else: + ret += role + ':\n' + return ret + elif self.sep_style == SeparatorStyle.FALCON_CHAT: + ret = '' + if self.system_message: + ret += system_prompt + self.sep + for role, message in self.messages: + if message: + ret += role + ': ' + message + self.sep + else: + ret += role + ':' + + return ret + elif self.sep_style == SeparatorStyle.INTERNVL_ZH: + seps = [self.sep2, self.sep] + ret = self.system_message + seps[0] + for i, (role, message) in enumerate(self.messages): + if message: + 
ret += role + ': ' + message + seps[i % 2] + else: + ret += role + ':' + return ret + elif self.sep_style == SeparatorStyle.MPT: + ret = system_prompt + self.sep + for role, message in self.messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + self.sep + else: + ret += role + return ret + else: + raise ValueError(f'Invalid style: {self.sep_style}') + + def set_system_message(self, system_message: str): + """Set the system message.""" + self.system_message = system_message + + def append_message(self, role: str, message: str): + """Append a new message.""" + self.messages.append([role, message]) + + def update_last_message(self, message: str): + """Update the last output. + + The last message is typically set to be None when constructing the prompt, + so we need to update it in-place after getting the response from a model. + """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + ret = [{'role': 'system', 'content': self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({'role': 'user', 'content': msg}) + else: + if msg is not None: + ret.append({'role': 'assistant', 'content': msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + stop_token_ids=self.stop_token_ids, + ) + + def dict(self): + return { + 'template_name': self.name, + 'system_message': self.system_message, + 'roles': self.roles, + 'messages': self.messages, + 'offset': self.offset, + } + + +# A global registry for all conversation templates +conv_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert ( + template.name not in conv_templates + ), f'{template.name} has been registered.' + + conv_templates[template.name] = template + + +def get_conv_template(name: str) -> Conversation: + """Get a conversation template.""" + return conv_templates[name].copy() + + +# InternVL-Chat-V1-1 template +register_conv_template( + Conversation( + name='internvl_zh', + system_template='', + roles=('', ''), + sep_style=SeparatorStyle.INTERNVL_ZH, + sep='', + sep2=' ', + ) +) + + +# Both Hermes-2 and internlm2-chat are chatml-format conversation templates. The difference +# is that during training, the preprocessing function for the Hermes-2 template doesn't add +# at the beginning of the tokenized sequence, while the internlm2-chat template does. +# Therefore, they are completely equivalent during inference. +register_conv_template( + Conversation( + name='Hermes-2', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. 
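As a usage sketch of the `Conversation` API defined above: the instance below is hypothetical (built directly rather than fetched from the registry) and only shows how `append_message` and `get_prompt` interact for the `ADD_COLON_SINGLE` separator style.

```python
from internvl.conversation import Conversation, SeparatorStyle, get_conv_template

# A hypothetical template purely for illustration (pass messages=[] explicitly;
# the dataclass default is an immutable empty tuple).
conv = Conversation(
    name='demo',
    system_message='You are a helpful assistant.',
    roles=('USER', 'ASSISTANT'),
    messages=[],
    sep_style=SeparatorStyle.ADD_COLON_SINGLE,
    sep='\n',
)
conv.append_message(conv.roles[0], 'Hello')
conv.append_message(conv.roles[1], None)   # None marks the slot the model will fill
print(conv.get_prompt())
# You are a helpful assistant.
# USER: Hello
# ASSISTANT:

# Registered templates are fetched by name instead, e.g. get_conv_template('internvl_zh').
```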
+ # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + stop_str='<|endoftext|>', + ) +) + + +register_conv_template( + Conversation( + name='internlm2-chat', + system_template='<|im_start|>system\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>', + ) +) + + +register_conv_template( + Conversation( + name='phi3-chat', + system_template='<|system|>\n{system_message}', + # note: The new system prompt was not used here to avoid changes in benchmark performance. + # system_message='我是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + system_message='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。', + roles=('<|user|>\n', '<|assistant|>\n'), + sep_style=SeparatorStyle.MPT, + sep='<|end|>', + ) +) + + +register_conv_template( + Conversation( + name='internvl2_5', + system_template='<|im_start|>system\n{system_message}', + system_message='你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', + roles=('<|im_start|>user\n', '<|im_start|>assistant\n'), + sep_style=SeparatorStyle.MPT, + sep='<|im_end|>\n', + ) +) diff --git a/internvl/dist_utils.py b/internvl/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb8ae27731968e1acc11134b6c204b6d3c39afa --- /dev/null +++ b/internvl/dist_utils.py @@ -0,0 +1,104 @@ +import os +import socket +import subprocess +from datetime import timedelta + +import deepspeed +import torch +import torch.multiprocessing as mp +from torch import distributed as dist + +timeout = timedelta(minutes=60) + + +def _find_free_port(): + # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(('', 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. 
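+    # (callers that care can re-check availability with _is_free_port() right
+    #  before binding, as _init_dist_slurm below does for the default port 29500)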
+ return port + + +def _is_free_port(port): + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append('localhost') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + # dist.init_process_group(backend=backend, **kwargs) + deepspeed.init_distributed(dist_backend=backend) + + +def _init_dist_mpi(backend, **kwargs): + local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + torch.cuda.set_device(local_rank) + if 'MASTER_PORT' not in os.environ: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + if 'MASTER_ADDR' not in os.environ: + raise KeyError('The environment variable MASTER_ADDR is not set') + os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] + os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. 
+ """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(29500): + os.environ['MASTER_PORT'] = '29500' + else: + os.environ['MASTER_PORT'] = str(_find_free_port()) + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + # dist.init_process_group(backend=backend, timeout=timeout) + deepspeed.init_distributed(dist_backend=backend) diff --git a/internvl/model/__init__.py b/internvl/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c006f6d39c59b0d98ad820994ba15b09bf96539b --- /dev/null +++ b/internvl/model/__init__.py @@ -0,0 +1,66 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import math + +import torch +from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel +from transformers import AutoTokenizer + + +def split_model(num_layers, vit_alpha=0.5): + device_map = {} + world_size = torch.cuda.device_count() + # Since the first GPU will be used for ViT, treat it as half a GPU. 
+ num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha)) + num_layers_per_gpu = [num_layers_per_gpu] * world_size + num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha)) + layer_cnt = 0 + for i, num_layer in enumerate(num_layers_per_gpu): + for j in range(num_layer): + device_map[f'language_model.model.layers.{layer_cnt}'] = i + layer_cnt += 1 + device_map['vision_model'] = 0 + device_map['mlp1'] = 0 + device_map['language_model.model.tok_embeddings'] = 0 + device_map['language_model.model.embed_tokens'] = 0 + device_map['language_model.output'] = 0 + device_map['language_model.model.norm'] = 0 + device_map['language_model.lm_head'] = 0 + device_map[f'language_model.model.layers.{num_layers - 1}'] = 0 + + return device_map + + +def load_model_and_tokenizer(args): + if args.auto: + config = InternVLChatConfig.from_pretrained(args.checkpoint) + num_hidden_layers = config.llm_config.num_hidden_layers + device_map = split_model(num_hidden_layers) + kwargs = {'device_map': device_map} if args.auto else {} + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) + model = InternVLChatModel.from_pretrained( + args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, + load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval() + if not args.load_in_8bit and not args.load_in_4bit and not args.auto: + model = model.cuda() + return model, tokenizer + +def load_model_and_tokenizer_customed(args): + if args.auto: + config = InternVLChatConfig.from_pretrained(args.checkpoint) + num_hidden_layers = config.llm_config.num_hidden_layers + device_map = split_model(num_hidden_layers) + kwargs = {'device_map': device_map} if args.auto else {} + tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) + model = InternVLChatModel.from_pretrained( + args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, + load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval() + if not args.load_in_8bit and not args.load_in_4bit and not args.auto: + del model.language_model.model.layers + del model.language_model.output + return model, tokenizer + diff --git a/internvl/model/__pycache__/__init__.cpython-310.pyc b/internvl/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14986ed766a1750e26dd590cb173849b754a6b2f Binary files /dev/null and b/internvl/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/internvl/model/__pycache__/__init__.cpython-39.pyc b/internvl/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ffe980293aabb7ab792d4ad980f774f0771484 Binary files /dev/null and b/internvl/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0412df00467d99bdd04fa62d05119a3205298158 Binary files /dev/null and b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-310.pyc differ diff --git a/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b5eace8fe3afa226523803b99c244bdb24ca30c Binary files 
/dev/null and b/internvl/model/internlm2/__pycache__/configuration_internlm2.cpython-39.pyc differ diff --git a/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..429e50e05cfa3d5b1206e6070808520cef3fa7e0 Binary files /dev/null and b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-310.pyc differ diff --git a/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5a5e246572a6541734175afe216efe2f193aee3 Binary files /dev/null and b/internvl/model/internlm2/__pycache__/modeling_internlm2.cpython-39.pyc differ diff --git a/internvl/model/internlm2/configuration_internlm2.py b/internvl/model/internlm2/configuration_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..282b13b1e2066ecc074ecae87b35a19d251f0ed7 --- /dev/null +++ b/internvl/model/internlm2/configuration_internlm2.py @@ -0,0 +1,150 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLM2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate + an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLM2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. 
+ num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + + """ + model_type = 'internlm2' + _auto_class = 'AutoConfig' + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act='silu', + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + bias=True, + rope_theta=10000, + rope_scaling=None, + attn_implementation='eager', + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.bias = bias + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + + self.attn_implementation = attn_implementation + if self.attn_implementation is None: + self.attn_implementation = 'eager' + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' + f'got {self.rope_scaling}' + ) + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_factor = self.rope_scaling.get('factor', None) + if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}") diff --git a/internvl/model/internlm2/modeling_internlm2.py b/internvl/model/internlm2/modeling_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..569513dffad6a2bce63c26d7463f90fbcc289b2c --- /dev/null +++ b/internvl/model/internlm2/modeling_internlm2.py @@ -0,0 +1,1429 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch InternLM2 model.""" +import math +import queue +import threading +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_start_docstrings, + add_start_docstrings_to_model_forward, logging, + replace_return_docstrings) + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + +from .configuration_internlm2 import InternLM2Config + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = 'InternLM2Config' + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None +try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + has_flash_attn = True +except: + has_flash_attn = False + + +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import \ + flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import \ + index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError('flash_attn is not installed.') + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. 
+ """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2 +class InternLM2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + InternLM2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +try: + from functools import partial + + from apex.normalization import FusedRMSNorm + InternLM2RMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa + print('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternLM2RMSNorm') +except ImportError: + # using the normal LlamaRMSNorm + pass +except Exception: + print('discovered apex but it failed to load, falling back to InternLM2RMSNorm') + pass + + +# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2 +class InternLM2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2 +class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding): + """InternLM2RotaryEmbedding extended with Dynamic NTK scaling. + Credits to the Reddit users /u/bloc97 and /u/emozilla. 
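A quick shape check of `InternLM2RotaryEmbedding` above; this is a minimal sketch that only relies on the repository's own module (flash-attn and apex remain optional imports).

```python
import torch
from internvl.model.internlm2.modeling_internlm2 import InternLM2RotaryEmbedding

# head_dim=64, so each position gets a 64-wide cos/sin row (two copies of 32 frequencies).
rope = InternLM2RotaryEmbedding(dim=64, max_position_embeddings=2048)
x = torch.zeros(1, 8, 16, 64)        # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(x, seq_len=16)
print(cos.shape, sin.shape)          # torch.Size([16, 64]) torch.Size([16, 64])
```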
+ """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype) + + freqs = torch.einsum('i,j->ij', t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False) + self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class InternLM2MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.w2(self.act_fn(self.w1(x)) * self.w3(x)) + + return down_proj + + +# Copied from transformers.model.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Modified from transformers.model.llama.modeling_llama.LlamaAttention +class InternLM2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternLM2Config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).' + ) + + self.wqkv = nn.Linear( + self.hidden_size, + (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim, + bias=config.bias, + ) + + self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = InternLM2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + scaling_factor = self.config.rope_scaling['factor'] + if scaling_type == 'dynamic': + self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + elif scaling_type == 'linear': + self.rotary_emb = InternLM2LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + scaling_factor=scaling_factor, + ) + else: + raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.") + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. 
' + 'Please make sure use `attention_mask` instead.`' + ) + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}' + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2 +class InternLM2FlashAttention2(InternLM2Attention): + """ + InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
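To unpack the fused `wqkv` projection used by `InternLM2Attention.forward` above, here is a small arithmetic sketch with illustrative 7B-style sizes (hidden_size=4096, 32 query heads, 8 KV heads; these numbers are assumptions, not read from a checkpoint).

```python
# Illustrative grouped-query attention bookkeeping.
hidden_size, num_heads, num_kv_heads = 4096, 32, 8
head_dim = hidden_size // num_heads                              # 128
num_key_value_groups = num_heads // num_kv_heads                 # 4
wqkv_out_features = (num_heads + 2 * num_kv_heads) * head_dim    # (32 + 16) * 128 = 6144
# rearrange(qkv, 'b q (h gs d) -> b q h gs d', gs=2 + num_key_value_groups, d=head_dim)
# gives h = 6144 / (6 * 128) = 8 KV heads, each carrying 4 query slices + 1 key + 1 value.
print(head_dim, num_key_value_groups, wqkv_out_features)         # 128 4 6144
```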
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # InternLM2FlashAttention2 attention does not support output_attentions + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' + 'Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.wqkv(hidden_states) + + qkv_states = rearrange( + qkv_states, + 'b q (h gs d) -> b q h gs d', + gs=2 + self.num_key_value_groups, + d=self.head_dim, + ) + + query_states = qkv_states[..., : self.num_key_value_groups, :] + query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d') + key_states = qkv_states[..., -2, :] + value_states = qkv_states[..., -1, :] + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len + ) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.wo(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
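+            # e.g. with left padding, a mask such as [[0, 0, 1, 1], [1, 1, 1, 1]] keeps only its last
+            # `query_length` columns here, so the remaining entries line up with the query tokens that
+            # `unpad_input` gathers just below.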
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q.to(torch.int64), + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +INTERNLM2_ATTENTION_CLASSES = { + 'eager': InternLM2Attention, + 'flash_attention_2': InternLM2FlashAttention2, +} + + +# Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer +class InternLM2DecoderLayer(nn.Module): + def __init__(self, config: InternLM2Config): + super().__init__() + self.hidden_size = config.hidden_size + + self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config) + + self.feed_forward = InternLM2MLP(config) + self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. ' + 'Please make sure use `attention_mask` instead.`' + ) + + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +InternLM2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`InternLM2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2 +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2PreTrainedModel(PreTrainedModel): + config_class = InternLM2Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['InternLM2DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +InternLM2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Modified from transformers.model.llama.modeling_llama.LlamaModel +@add_start_docstrings( + 'The bare InternLM2 Model outputting raw hidden-states without any specific head on top.', + InternLM2_START_DOCSTRING, +) +class InternLM2Model(InternLM2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLM2DecoderLayer`] + + Args: + config: InternLM2Config + """ + + _auto_class = 'AutoModel' + + def __init__(self, config: InternLM2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + if not has_flash_attn: + self.config.attn_implementation = 'eager' + print('Warning: Flash attention is not available, using eager attention instead.') + + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.tok_embeddings + + def set_input_embeddings(self, value): + self.tok_embeddings = value + + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.attn_implementation == 'flash_attention_2': + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = 
input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.tok_embeddings(input_ids) + + if self.config.attn_implementation == 'flash_attention_2': + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +# Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM +class InternLM2ForCausalLM(InternLM2PreTrainedModel): + _auto_class = 'AutoModelForCausalLM' + + _tied_weights_keys = ['output.weight'] + + def __init__(self, config): + super().__init__(config) + self.model = InternLM2Model(config) + self.vocab_size = config.vocab_size + self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + def get_output_embeddings(self): + return self.output + + def set_output_embeddings(self, new_embeddings): + 
self.output = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, InternLM2ForCausalLM + + >>> model = InternLM2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.output(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + device = input_ids.device if input_ids is not None else inputs_embeds.device + output = CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + output['logits'] = output['logits'].to(device) + return output + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1]:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=''): + if tokenizer.add_bos_token: + prompt = '' + else: + prompt = tokenizer.bos_token + if meta_instruction: + prompt += 
f"""<|im_start|>system\n{meta_instruction}<|im_end|>\n""" + for record in history: + prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n""" + prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n""" + return tokenizer([prompt], return_tensors='pt') + + @torch.no_grad() + def chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: str = 'You are an AI assistant whose name is InternLM (书生·浦语).\n' + '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n' + '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.', + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} + # also add end-of-assistant token in eos token id to avoid unnecessary generation + eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(['<|im_end|>'])[0]] + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + eos_token_id=eos_token_id, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs['input_ids'][0]):] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split('<|im_end|>')[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + 'The version of `transformers` is too low. Please make sure ' + 'that you have installed `transformers>=4.28.0`.' 
+ ) + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = '' + self.cache = [] + self.received_inputs = False + self.queue.put((self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError('ChatStreamer only supports batch size 1') + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode(self.cache, skip_special_tokens=True) + if token.strip() != '<|im_end|>': + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +# Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2 +@add_start_docstrings( + """ + The InternLM2 Model transformer with a sequence classification head on top (linear layer). + + [`InternLM2ForSequenceClassification`] uses the last token in order to do the classification, + as other causal models (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
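+
+    For example, with `pad_token_id = 0` and a row `input_ids = [11, 12, 0, 0]`, the pooled classification
+    logits are taken from position 1 (the last non-padding token); if no `pad_token_id` is configured, they
+    are taken from the final position of each row.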
+ """, + InternLM2_START_DOCSTRING, +) +class InternLM2ForSequenceClassification(InternLM2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLM2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.tok_embeddings + + def set_input_embeddings(self, value): + self.model.tok_embeddings = value + + @add_start_docstrings_to_model_forward(InternLM2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.') + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = 
(pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/internvl/model/internlm2/tokenization_internlm2.py b/internvl/model/internlm2/tokenization_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..1be581da37ef678de65f2737493fc0ed7160446e --- /dev/null +++ b/internvl/model/internlm2/tokenization_internlm2.py @@ -0,0 +1,235 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer +class InternLM2Tokenizer(PreTrainedTokenizer): + """ + Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
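+
+    Example (an illustrative sketch; the path below is a placeholder for a real SentencePiece model file):
+
+    ```python
+    >>> tokenizer = InternLM2Tokenizer("path/to/tokenizer.model")
+    >>> tokenizer("Hello world")["input_ids"]
+    ```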
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ['input_ids', 'attention_mask'] + _auto_class = 'AutoTokenizer' + + def __init__( + self, + vocab_file, + unk_token='', + bos_token='', + eos_token='', + pad_token='', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.decode_with_prefix_space = decode_with_prefix_space + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + self._no_prefix_space_tokens = None + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def no_prefix_space_tokens(self): + if self._no_prefix_space_tokens is None: + vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) + self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')} + return self._no_prefix_space_tokens + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def _maybe_add_prefix_space(self, tokens, decoded): + if tokens and tokens[0] not in self.no_prefix_space_tokens: + return ' ' + decoded + else: + return decoded + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = '' + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += ' ' + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + out_string = self.clean_up_tokenization(out_string) + out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) + return out_string[1:] + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, 'wb') as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/internvl/model/internlm2/tokenization_internlm2_fast.py b/internvl/model/internlm2/tokenization_internlm2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0fccbd0f1d029d79e19821f2edcb01b594537c --- /dev/null +++ b/internvl/model/internlm2/tokenization_internlm2_fast.py @@ -0,0 +1,211 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization Fast class for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple + +from tokenizers import Tokenizer, decoders, normalizers, processors +from tokenizers.models import BPE +from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS, + SentencePieceExtractor, + SpmConverter) +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import logging + +from .tokenization_internlm2 import InternLM2Tokenizer + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} + + +# Modified from transformers.convert_slow_tokenizer.LlamaConverter +class InternLM2Converter(SpmConverter): + handle_byte_fallback = True + + def vocab(self, proto): + vocab = [ + ('', 0.0), + ('', 0.0), + ('', 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + unk_id = 0 + return unk_id + + def decoder(self, replacement, add_prefix_space): + return decoders.Sequence( + [ + decoders.Replace('▁', ' '), + decoders.ByteFallback(), + decoders.Fuse(), + decoders.Strip(content=' ', left=1), + ] + ) + + def tokenizer(self, proto): + model_type = proto.trainer_spec.model_type + vocab_scores = self.vocab(proto) + # special tokens + added_tokens = self.original_tokenizer.added_tokens_decoder + for i in range(len(vocab_scores)): + piece, score = vocab_scores[i] + if i in added_tokens: + vocab_scores[i] = (added_tokens[i].content, score) + if model_type == 1: + raise RuntimeError('InternLM2 is supposed to be a BPE model!') + + elif model_type == 2: + _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) + bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} + tokenizer = Tokenizer( + BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True) + ) + tokenizer.add_special_tokens( + [ added_token for index, added_token in added_tokens.items()] + ) + else: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + return tokenizer + + def normalizer(self, proto): + normalizers_list = [] + if proto.normalizer_spec.add_dummy_prefix: + normalizers_list.append(normalizers.Prepend(prepend='▁')) + normalizers_list.append(normalizers.Replace(pattern=' ', content='▁')) + return normalizers.Sequence(normalizers_list) + + def pre_tokenizer(self, replacement, add_prefix_space): + return None + + +SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter + + +# Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast +class InternLM2TokenizerFast(PreTrainedTokenizerFast): + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = InternLM2Tokenizer + padding_side = 'left' + model_input_names = ['input_ids', 'attention_mask'] + _auto_class = 'AutoTokenizer' + + def __init__( + self, + vocab_file, + unk_token='', + bos_token='', + eos_token='', + pad_token='', + sp_model_kwargs: 
Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + sp_model_kwargs=sp_model_kwargs, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + decode_with_prefix_space=decode_with_prefix_space, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError('add_bos_token = True but bos_token = None') + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError('add_eos_token = True but eos_token = None') + + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' + 'tokenizer.' 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) diff --git a/internvl/model/internvl_chat/__init__.py b/internvl/model/internvl_chat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d57341208843b8ac0376b4f8cd3e3b9186e3621 --- /dev/null +++ b/internvl/model/internvl_chat/__init__.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from .configuration_intern_vit import InternVisionConfig +from .configuration_internvl_chat import InternVLChatConfig +from .modeling_intern_vit import InternVisionModel +from .modeling_internvl_chat import InternVLChatModel + +__all__ = ['InternVisionConfig', 'InternVisionModel', + 'InternVLChatConfig', 'InternVLChatModel'] diff --git a/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e81755f6af990c96b4a8d5e470a3d2cb30c90325 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/__init__.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1a1de5ed144acd5da2a4effba7536a27d1371dc Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/__init__.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49a8532ae09c8d389d82984ebaf67ba8a40c90b8 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feb5af75a4978ce1222e0ad77cde59d0b08d71a7 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_intern_vit.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..811ee0d8f78aa3c0c096073de822a954768de549 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71e4f6c5fe4899fd18e5dfeac3200312d305b721 Binary files /dev/null and 
b/internvl/model/internvl_chat/__pycache__/configuration_internvl_chat.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82264473109e907b5f0169ff9e194b6c64308c9b Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e54913321f8bdf9252a241371c10d1841884c96 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/flash_attention.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2d4e073dd27e2b99ffa8c69b5b02ba931f14a42 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5f8d763e6f76283c546c72eb0cefb0e156af6b9 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_intern_vit.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3976792dd9d27d4dde8d4affd3c2413b78b1dc33 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-310.pyc differ diff --git a/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e870ef669b242afc8c54969591bbfaecb2267cb9 Binary files /dev/null and b/internvl/model/internvl_chat/__pycache__/modeling_internvl_chat.cpython-39.pyc differ diff --git a/internvl/model/internvl_chat/configuration_intern_vit.py b/internvl/model/internvl_chat/configuration_intern_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..7e630c456eb9cf350e55bf850c3ff72f445a7e17 --- /dev/null +++ b/internvl/model/internvl_chat/configuration_intern_vit.py @@ -0,0 +1,120 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class InternVisionConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to + instantiate a vision encoder according to the specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_channels (`int`, *optional*, defaults to 3): + Number of color channels in the input images (e.g., 3 for RGB). + patch_size (`int`, *optional*, defaults to 14): + The size (resolution) of each patch. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + qkv_bias (`bool`, *optional*, defaults to `False`): + Whether to add a bias to the queries and values in the self-attention layers. + hidden_size (`int`, *optional*, defaults to 3200): + Dimensionality of the encoder layers and the pooler layer. + num_attention_heads (`int`, *optional*, defaults to 25): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 12800): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + qk_normalization (`bool`, *optional*, defaults to `True`): + Whether to normalize the queries and keys in the self-attention layers. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + use_flash_attn (`bool`, *optional*, defaults to `True`): + Whether to use flash attention mechanism. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-6): + The epsilon used by the layer normalization layers. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + drop_path_rate (`float`, *optional*, defaults to 0.0): + Dropout rate for stochastic depth. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + initializer_factor (`float`, *optional*, defaults to 0.1): + A factor for layer scale. 
+ """ + + model_type = 'intern_vit_6b' + + def __init__( + self, + num_channels=3, + patch_size=14, + image_size=224, + qkv_bias=False, + hidden_size=3200, + num_attention_heads=25, + intermediate_size=12800, + qk_normalization=True, + num_hidden_layers=48, + use_flash_attn=True, + hidden_act='gelu', + norm_type='rms_norm', + layer_norm_eps=1e-6, + dropout=0.0, + drop_path_rate=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=0.1, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.dropout = dropout + self.drop_path_rate = drop_path_rate + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.initializer_range = initializer_range + self.initializer_factor = initializer_factor + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.norm_type = norm_type + self.qkv_bias = qkv_bias + self.qk_normalization = qk_normalization + self.use_flash_attn = use_flash_attn + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if 'vision_config' in config_dict: + config_dict = config_dict['vision_config'] + + if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' + ) + + return cls.from_dict(config_dict, **kwargs) diff --git a/internvl/model/internvl_chat/configuration_internvl_chat.py b/internvl/model/internvl_chat/configuration_internvl_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..eb0e502c8ed44a2680693fc6e5ed1f53b01f3730 --- /dev/null +++ b/internvl/model/internvl_chat/configuration_internvl_chat.py @@ -0,0 +1,93 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import copy + +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from .configuration_intern_vit import InternVisionConfig + +logger = logging.get_logger(__name__) + + +class InternVLChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__( + self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + pad2square=False, + select_layer=-1, + force_image_size=None, + hidden_size=2048, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + # import pdb; pdb.set_trace() + if vision_config is None: + vision_config = {} + logger.info('vision_config is None. 
Initializing the InternVisionConfig with default values.') + + self.vision_config = InternVisionConfig(**vision_config) + self.llm_config = None + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.pad2square = pad2square + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + + self.hidden_size = hidden_size + self.tie_word_embeddings = False + + logger.info(f'vision_select_layer: {self.select_layer}') + logger.info(f'ps_version: {self.ps_version}') + logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') + logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. + + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output['vision_config'] = self.vision_config.to_dict() + output['model_type'] = self.__class__.model_type + output['use_backbone_lora'] = self.use_backbone_lora + output['use_llm_lora'] = self.use_llm_lora + output['pad2square'] = self.pad2square + output['select_layer'] = self.select_layer + output['force_image_size'] = self.force_image_size + output['downsample_ratio'] = self.downsample_ratio + output['template'] = self.template + output['dynamic_image_size'] = self.dynamic_image_size + output['use_thumbnail'] = self.use_thumbnail + output['ps_version'] = self.ps_version + output['min_dynamic_patch'] = self.min_dynamic_patch + output['max_dynamic_patch'] = self.max_dynamic_patch + + return output diff --git a/internvl/model/internvl_chat/flash_attention.py b/internvl/model/internvl_chat/flash_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7cda9bfadd290da35bdd04cccd51725e2d419c2f --- /dev/null +++ b/internvl/model/internvl_chat/flash_attention.py @@ -0,0 +1,76 @@ +# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py +import torch +import torch.nn as nn +from einops import rearrange + +try: # v1 + from flash_attn.flash_attn_interface import \ + flash_attn_unpadded_qkvpacked_func +except: # v2 + from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func + +from flash_attn.bert_padding import pad_input, unpad_input + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. 
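To see what the `to_dict` override of `InternVLChatConfig` above produces, here is a hedged sketch; it assumes the package layout introduced in this diff is importable, and the argument values are arbitrary examples.

```python
# Hedged sketch: exercising InternVLChatConfig.to_dict() as defined above.
# Assumes the internvl package from this diff is on PYTHONPATH.
from internvl.model.internvl_chat.configuration_internvl_chat import InternVLChatConfig

cfg = InternVLChatConfig(force_image_size=448, dynamic_image_size=True, max_dynamic_patch=12)
d = cfg.to_dict()
assert d['model_type'] == 'internvl_chat'     # injected by the override
assert isinstance(d['vision_config'], dict)   # nested vision config serialized as a plain dict
```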
(B, S, 3, H, D) if key_padding_mask is None + if unpadded: (nnz, 3, h, d) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert not need_weights + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + + return output, None diff --git a/internvl/model/internvl_chat/modeling_intern_vit.py b/internvl/model/internvl_chat/modeling_intern_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..3b98fe88c188f602b370b5f1a55f0bced38eb489 --- /dev/null +++ b/internvl/model/internvl_chat/modeling_intern_vit.py @@ -0,0 +1,364 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +import numpy as np +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from timm.models.layers import DropPath +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import (BaseModelOutput, + BaseModelOutputWithPooling) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +from .configuration_intern_vit import InternVisionConfig + +try: + from .flash_attention import FlashAttention + has_flash_attn = True +except: + print('FlashAttention is not installed.') + has_flash_attn = False + +logger = logging.get_logger(__name__) + + +class InternRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +try: + from apex.normalization import FusedRMSNorm + + InternRMSNorm = FusedRMSNorm # noqa + + logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm') +except ImportError: + # using 
the normal InternRMSNorm + pass +except Exception: + logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm') + pass + + +NORM2FN = { + 'rms_norm': InternRMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter( + torch.randn(1, 1, self.embed_dim), + ) + + self.patch_embedding = nn.Conv2d( + in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed, H, W): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \ + reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) + return pos_embed + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = torch.cat([ + self.position_embedding[:, :1, :], + self._get_pos_embed(self.position_embedding[:, 1:, :], height, width) + ], dim=1) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_flash_attn = config.use_flash_attn and has_flash_attn + if config.use_flash_attn and not has_flash_attn: + print('Warning: Flash Attention is not available, use_flash_attn is set to False.') + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).' 
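The `_get_pos_embed` helper above resizes the learned position embedding when the patch grid changes; the standalone sketch below (toy sizes, illustrative names, not the repository code) mirrors that bicubic resampling.

```python
import torch
import torch.nn.functional as F

# Standalone sketch of the _get_pos_embed logic above: resample a learned
# (1, 1 + N, C) position embedding onto a new patch grid with bicubic
# interpolation, keeping the class-token slot untouched.
def resize_patch_pos_embed(pos_embed, old_grid, new_h, new_w):
    cls_tok, patch_tok = pos_embed[:, :1], pos_embed[:, 1:]
    patch_tok = patch_tok.reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
    patch_tok = F.interpolate(patch_tok.float(), size=(new_h, new_w), mode='bicubic', align_corners=False)
    patch_tok = patch_tok.reshape(1, -1, new_h * new_w).permute(0, 2, 1).to(pos_embed.dtype)
    return torch.cat([cls_tok, patch_tok], dim=1)

pe = torch.randn(1, 1 + 16 * 16, 64)                  # 16x16 grid (224/14), toy width 64
print(resize_patch_pos_embed(pe, 16, 32, 32).shape)   # torch.Size([1, 1025, 64])
```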
+ ) + + self.scale = self.head_dim ** -0.5 + self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) + self.attn_drop = nn.Dropout(config.attention_dropout) + self.proj_drop = nn.Dropout(config.dropout) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + if self.use_flash_attn: + self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) + self.proj = nn.Linear(self.embed_dim, self.embed_dim) + + def _naive_attn(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + if self.qk_normalization: + B_, H_, N_, D_ = q.shape + q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + + attn = ((q * self.scale) @ k.transpose(-2, -1)) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _flash_attn(self, x, key_padding_mask=None, need_weights=False): + qkv = self.qkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + + if self.qk_normalization: + q, k, v = qkv.unbind(2) + q = self.q_norm(q.flatten(-2, -1)).view(q.shape) + k = self.k_norm(k.flatten(-2, -1)).view(k.shape) + qkv = torch.stack([q, k, v], dim=2) + + context, _ = self.inner_attn( + qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False + ) + outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) + outs = self.proj_drop(outs) + return outs + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states) + return x + + +class InternMLP(nn.Module): + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + self.act = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + def __init__(self, config: InternVisionConfig, drop_path_rate: float): + super().__init__() + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = InternAttention(config) + self.mlp = InternMLP(config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) + self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: + """ + Args: + hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` + """ + hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1) + + hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2) + + return hidden_states + + +class InternVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`InternEncoderLayer`]. + + Args: + config (`InternConfig`): + The corresponding vision configuration for the `InternEncoder`. + """ + + def __init__(self, config: InternVisionConfig): + super().__init__() + self.config = config + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) + self.gradient_checkpointing = True + + def forward( + self, + inputs_embeds, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Embedded representation of the inputs. Should be float, not int tokens. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
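The QK-normalization path in `_naive_attn` above flattens the head dimension so that a single RMSNorm of width `embed_dim` covers all heads at once; this standalone sketch (toy shapes, illustrative names) reproduces just that reshaping.

```python
import torch
import torch.nn as nn

# Standalone sketch of the qk_normalization reshaping used in _naive_attn above.
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return self.weight * (x.float() * torch.rsqrt(var + self.eps)).to(x.dtype)

B, H, N, D = 2, 25, 257, 128                 # 25 heads x 128 dims = 3200 = embed_dim
q = torch.randn(B, H, N, D)                  # per-head queries
q_norm = RMSNorm(H * D)
# (B, H, N, D) -> (B, N, H*D) -> normalize -> back to (B, H, N, D)
q = q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B, N, H, D).transpose(1, 2)
print(q.shape)                               # torch.Size([2, 25, 257, 128])
```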
+ """ + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + encoder_layer, + hidden_states) + else: + layer_outputs = encoder_layer( + hidden_states, + ) + hidden_states = layer_outputs + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states + ) + + +class InternVisionModel(PreTrainedModel): + main_input_name = 'pixel_values' + _supports_flash_attn_2 = True + config_class = InternVisionConfig + _no_split_modules = ['InternVisionEncoderLayer'] + + def __init__(self, config: InternVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder(config) + + def resize_pos_embeddings(self, old_size, new_size, patch_size): + pos_emb = self.embeddings.position_embedding + _, num_positions, embed_dim = pos_emb.shape + cls_emb = pos_emb[:, :1, :] + pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) + pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) + pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) + pos_emb = torch.cat([cls_emb, pos_emb], dim=1) + self.embeddings.position_embedding = nn.Parameter(pos_emb) + self.embeddings.image_size = new_size + logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if pixel_values is None and pixel_embeds is None: + raise ValueError('You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + if len(pixel_values.shape) == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = encoder_outputs.last_hidden_state + pooled_output = last_hidden_state[:, 0, :] + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) diff --git 
a/internvl/model/internvl_chat/modeling_internvl_chat.py b/internvl/model/internvl_chat/modeling_internvl_chat.py new file mode 100644 index 0000000000000000000000000000000000000000..618f0ca0cfe00edac96198930dfc11cbf2cce2cd --- /dev/null +++ b/internvl/model/internvl_chat/modeling_internvl_chat.py @@ -0,0 +1,424 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import numpy as np +import warnings +from typing import Any, List, Optional, Tuple, Union +import torch.nn.functional as F +import torch.distributed as dist +import torch.utils.checkpoint +import transformers +from internvl.conversation import get_conv_template +from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM +from internvl.model.phi3.modeling_phi3 import Phi3ForCausalLM +from peft import LoraConfig, get_peft_model +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM, + LlamaTokenizer, Qwen2ForCausalLM) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput, logging + +from .configuration_internvl_chat import InternVLChatConfig +from .modeling_intern_vit import InternVisionModel + +logger = logging.get_logger(__name__) + + +def version_cmp(v1, v2, op='eq'): + import operator + + from packaging import version + op_func = getattr(operator, op) + return op_func(version.parse(v1), version.parse(v2)) + + +class InternVLChatModel(PreTrainedModel): + config_class = InternVLChatConfig + main_input_name = 'pixel_values' + base_model_prefix = '' + _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer', + 'Phi3DecoderLayer', 'Qwen2DecoderLayer'] + _supports_flash_attn_2 = True + supports_gradient_checkpointing = True + + def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None): + super().__init__(config) + + assert version_cmp(transformers.__version__, '4.37.0', 'ge') + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + self.select_layer = config.select_layer + self.template = config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + # self.llm_arch_name = config.llm_config.architectures[0] + + logger.info(f'num_image_token: {self.num_image_token}') + logger.info(f'ps_version: {self.ps_version}') + if vision_model is not None: + self.vision_model = vision_model + else: + self.vision_model = InternVisionModel(config.vision_config) + self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, 2).to(torch.bfloat16) + + # if language_model is not None: + # self.language_model = language_model + # else: + # if config.llm_config.architectures[0] == 'LlamaForCausalLM': + # self.language_model = LlamaForCausalLM(config.llm_config) + # elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM': + # self.language_model = InternLM2ForCausalLM(config.llm_config) + # elif config.llm_config.architectures[0] == 'Phi3ForCausalLM': + # self.language_model = Phi3ForCausalLM(config.llm_config) + # elif config.llm_config.architectures[0] == 
'Qwen2ForCausalLM': + # self.language_model = Qwen2ForCausalLM(config.llm_config) + # else: + # raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.') + + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.hidden_size + + self.ocr_mlp = nn.Sequential( + nn.LayerNorm(vit_hidden_size), + nn.Linear(vit_hidden_size, llm_hidden_size), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size) + ) + if config.train_stage > 1: + self.mlp1 = nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size) + ) + + self.img_context_token_id = None + self.conv_template = get_conv_template(self.template) + if hasattr(config, 'system_message'): + self.system_message = config.system_message + else: + self.system_message = self.conv_template.system_message + self.num_samples = 0 + + if config.use_backbone_lora: + self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora) + + if config.use_llm_lora: + self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora) + + init_tau=np.log(10) + init_b=-2.71 + self.t_prime = nn.Parameter(torch.ones([]) * init_tau) + self.b = nn.Parameter(torch.ones([]) * init_b) + self.kb = False + self.upsample = nn.Sequential( + nn.ConvTranspose2d( + in_channels=1024, + out_channels=512, + kernel_size=4, + stride=2, + padding=1, + bias=False + ), + # nn.BatchNorm2d(512), + nn.SyncBatchNorm(512), + # 第二层反卷积:进一步上采样到目标分辨率 + nn.ConvTranspose2d( + in_channels=512, + out_channels=1024, + kernel_size=4, + stride=2, + padding=1, + bias=False + ), + # nn.BatchNorm2d(1024), + nn.SyncBatchNorm(1024), + ) + + def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): + lora_config = LoraConfig( + r=r, + target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'], + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + self.vision_model = get_peft_model(self.vision_model, lora_config) + self.vision_model.print_trainable_parameters() + + def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): + # Determine the target modules based on the architecture of the language model + if self.llm_arch_name == 'InternLM2ForCausalLM': + target_modules = ['attention.wqkv', 'attention.wo', 'feed_forward.w1', 'feed_forward.w2', 'feed_forward.w3'] + elif self.llm_arch_name == 'Phi3ForCausalLM': + target_modules = ['mlp.down_proj', 'mlp.gate_up_proj', 'self_attn.o_proj', 'self_attn.qkv_proj'] + elif self.llm_arch_name in ['Qwen2ForCausalLM', 'LlamaForCausalLM']: + target_modules = ['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj', + 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'] + else: + raise NotImplemented + lora_config = LoraConfig( + r=r, + target_modules=target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + task_type='CAUSAL_LM' + ) + self.language_model = get_peft_model(self.language_model, lora_config) + self.language_model.enable_input_require_grads() + self.language_model.print_trainable_parameters() + + def forward_tokenocr(self, + pixel_values: torch.FloatTensor)-> Union[Tuple, CausalLMOutputWithPast]: + vit_embeds = self.extract_feature_custom(pixel_values) #(vit_batch_size, 16*16, 2048) + # vit_embeds = self.extract_feature_custom_no_upsample(pixel_values) #(vit_batch_size, 16*16, 2048) + return vit_embeds, None + + + + def pixel_unshuffle(self, 
x, scale_factor=4): + h = w = int(x.shape[1] ** 0.5) + n, l, c = x.size() + x = x.reshape(n, h, w, c) + x = x.repeat_interleave(scale_factor, dim=1).repeat_interleave(scale_factor, dim=2) + return x + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " + 'which results in a transposed image.') + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True).last_hidden_state + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=True).hidden_states[self.select_layer] + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def extract_feature_custom(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True).last_hidden_state + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=True).hidden_states[self.select_layer] + vit_embeds = vit_embeds[:, 1:, :] # (52, 1025, 1024) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.permute(0,2,1).reshape(vit_embeds.shape[0], -1, h, w) + vit_embeds = self.upsample(vit_embeds) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1]) + vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1)) + return vit_embeds + + def extract_feature_custom_no_upsample(self, pixel_values): + if self.select_layer == -1: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True).last_hidden_state + else: + vit_embeds = self.vision_model( + pixel_values=pixel_values, + output_hidden_states=True, + return_dict=True).hidden_states[self.select_layer] + vit_embeds = vit_embeds[:, 1:, :] # (52, 1025, 1024) + + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = self.ocr_mlp(vit_embeds) + # vit_embeds = self.pixel_unshuffle(vit_embeds) + # vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + return vit_embeds + + def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None, + history=None, return_history=False, IMG_START_TOKEN='', IMG_END_TOKEN='', + IMG_CONTEXT_TOKEN='', verbose=False, image_counts=None): + if history is not None or return_history: + print('Now multi-turn chat is not supported in batch_chat.') + raise NotImplementedError + + if image_counts is not None: + num_patches_list = image_counts + print('Warning: `image_counts` is deprecated. 
Please use `num_patches_list` instead.') + + img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.img_context_token_id = img_context_token_id + + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + print(f'dynamic ViT batch size: {image_bs}') + + queries = [] + for idx, num_patches in enumerate(num_patches_list): + question = questions[idx] + if pixel_values is not None and '' not in question: + question = '\n' + question + template = get_conv_template(self.template) + template.system_message = self.system_message + template.append_message(template.roles[0], question) + template.append_message(template.roles[1], None) + query = template.get_prompt() + + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace('', image_tokens, 1) + queries.append(query) + + tokenizer.padding_side = 'left' + model_inputs = tokenizer(queries, return_tensors='pt', padding=True) + input_ids = model_inputs['input_ids'].cuda() + attention_mask = model_inputs['attention_mask'].cuda() + eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) + generation_config['eos_token_id'] = eos_token_id + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) + responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True) + responses = [response.split(template.sep.strip())[0].strip() for response in responses] + return responses + + def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, + num_patches_list=None, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN='', + verbose=False): + + if history is None and pixel_values is not None and '' not in question: + question = '\n' + question + + if num_patches_list is None: + num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + + img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.img_context_token_id = img_context_token_id + + template = get_conv_template(self.template) + template.system_message = self.system_message + eos_token_id = tokenizer.convert_tokens_to_ids(template.sep.strip()) + + history = [] if history is None else history + for (old_question, old_answer) in history: + template.append_message(template.roles[0], old_question) + template.append_message(template.roles[1], old_answer) + template.append_message(template.roles[0], question) + template.append_message(template.roles[1], None) + query = template.get_prompt() + + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + print(f'dynamic ViT batch size: {image_bs}') + + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace('', image_tokens, 1) + + model_inputs = tokenizer(query, return_tensors='pt') + input_ids = model_inputs['input_ids'].cuda() + attention_mask = model_inputs['attention_mask'].cuda() + generation_config['eos_token_id'] = eos_token_id + generation_output = self.generate( + pixel_values=pixel_values, + input_ids=input_ids, + attention_mask=attention_mask, + **generation_config + ) + response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] + response = response.split(template.sep.strip())[0].strip() + 
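The `pixel_shuffle` helper above trades spatial resolution for channel width before the projector; the sketch below (toy shapes) shows the shape arithmetic for `scale_factor=0.5`, including the final H/W swap that the `ps_version='v2'` path performs.

```python
import torch

# Hedged sketch of pixel_shuffle above with scale_factor=0.5: an (N, H, W, C)
# grid of ViT tokens becomes (N, H/2, W/2, 4C), i.e. 4x fewer tokens with
# 4x wider features.
def pixel_shuffle(x, scale=0.5):
    n, w, h, c = x.size()
    x = x.view(n, w, int(h * scale), int(c / scale))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale), int(w * scale), int(c / (scale * scale)))
    return x.permute(0, 2, 1, 3).contiguous()   # the 'v2' path swaps H and W back

tokens = torch.randn(1, 32, 32, 1024)           # toy 32x32 patch grid
print(pixel_shuffle(tokens).shape)              # torch.Size([1, 16, 16, 4096])
```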
history.append((question, response)) + if return_history: + return response, history + else: + query_to_print = query.replace(IMG_CONTEXT_TOKEN, '') + query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '') + if verbose: + print(query_to_print, response) + return response + + @torch.no_grad() + def generate( + self, + pixel_values: Optional[torch.FloatTensor] = None, + input_ids: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + visual_features: Optional[torch.FloatTensor] = None, + generation_config: Optional[GenerationConfig] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **generate_kwargs, + ) -> torch.LongTensor: + + assert self.img_context_token_id is not None + if pixel_values is not None: + if visual_features is not None: + vit_embeds = visual_features + else: + vit_embeds = self.extract_feature(pixel_values) + input_embeds = self.language_model.get_input_embeddings()(input_ids) + B, N, C = input_embeds.shape + input_embeds = input_embeds.reshape(B * N, C) + + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.img_context_token_id) + assert selected.sum() != 0 + input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) + + input_embeds = input_embeds.reshape(B, N, C) + else: + input_embeds = self.language_model.get_input_embeddings()(input_ids) + + outputs = self.language_model.generate( + inputs_embeds=input_embeds, + attention_mask=attention_mask, + generation_config=generation_config, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=True, + **generate_kwargs, + ) + + return outputs + + @property + def lm_head(self): + return self.language_model.get_output_embeddings() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() diff --git a/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a08c18584f23108bd92a2b9f30ae5b11561fd1f Binary files /dev/null and b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-310.pyc differ diff --git a/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69240f7ea2679fc67a9ee3ae76d9ff72fcb62331 Binary files /dev/null and b/internvl/model/phi3/__pycache__/configuration_phi3.cpython-39.pyc differ diff --git a/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97c5684f72df959c344a8e775412519e295d882e Binary files /dev/null and b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-310.pyc differ diff --git a/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30e2ed25a0480075cfdbba62c66dbce3419de256 Binary files /dev/null and b/internvl/model/phi3/__pycache__/modeling_phi3.cpython-39.pyc differ diff --git a/internvl/model/phi3/configuration_phi3.py b/internvl/model/phi3/configuration_phi3.py new file mode 100644 index 
0000000000000000000000000000000000000000..c657051097ebd7655786d74f8ed75635bfc844c4 --- /dev/null +++ b/internvl/model/phi3/configuration_phi3.py @@ -0,0 +1,211 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License atd +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Phi-3 model configuration""" + + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json', + 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json', +} + + +class Phi3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the + [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 32064): + Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Phi3Model`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 32000): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 32000): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. 
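Concretely, the length constraint on `short_factor` and `long_factor` works out to `hidden_size // num_attention_heads // 2 = 3072 // 32 // 2 = 48` with the defaults above. The hedged sketch below builds a dictionary that passes the `_rope_scaling_validation` check defined further down; it assumes the package layout introduced in this diff is importable.

```python
# Hedged sketch: a rope_scaling dict consistent with the validation below.
# Assumes the internvl package from this diff is on PYTHONPATH.
from internvl.model.phi3.configuration_phi3 import Phi3Config

factors = [1.0] * (3072 // 32 // 2)   # 48 entries per factor list
cfg = Phi3Config(
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    rope_scaling={'type': 'su', 'short_factor': factors, 'long_factor': factors},
)
```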
+ + Example: + + ```python + >>> from transformers import Phi3Model, Phi3Config + + >>> # Initializing a Phi-3 style configuration + >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + + >>> # Initializing a model from the configuration + >>> model = Phi3Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = 'phi3' + keys_to_ignore_at_inference = ['past_key_values'] + + def __init__( + self, + vocab_size=32064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act='silu', + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + bos_token_id=1, + eos_token_id=32000, + pad_token_id=32000, + sliding_window=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, ' + f'got {self.rope_scaling}' + ) + rope_scaling_type = self.rope_scaling.get('type', None) + rope_scaling_short_factor = self.rope_scaling.get('short_factor', None) + rope_scaling_long_factor = self.rope_scaling.get('long_factor', None) + if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']: + raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/internvl/model/phi3/modeling_phi3.py b/internvl/model/phi3/modeling_phi3.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb5dc0450c0435262f5d32afa0533ad4161d92d --- /dev/null +++ b/internvl/model/phi3/modeling_phi3.py @@ -0,0 +1,1610 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" PyTorch Phi-3 model.""" + +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import \ + _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import (BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import (add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, logging, + replace_return_docstrings) + +from .configuration_phi3 import Phi3Config + +logger = logging.get_logger(__name__) + +# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements +# if is_flash_attn_2_available(): +_flash_supports_window_size = False +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa + unpad_input) + + _flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters) + has_flash_attn = True +except ImportError as error: + logger.warning( + f'`flash-attention` package not found, consider installing for better performance: {error}.' + ) + if not _flash_supports_window_size: + logger.warning( + "Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`." 
+ ) + has_flash_attn = False + +_CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct' +_CONFIG_FOR_DOC = 'Phi3Config' + +PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [ + 'microsoft/Phi-3-mini-4k-instruct', + 'microsoft/Phi-3-mini-128k-instruct', + # See all Phi-3 models at https://huggingface.co/models?filter=Phi-3 +] + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class Phi3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class Phi3RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer('inv_freq', None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling['short_factor'] + self.long_factor = config.rope_scaling['long_factor'] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + 
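As a standalone illustration of the `_get_unpad_data` helper above, the sketch below converts a padding mask into the flat token indices and cumulative sequence lengths that the varlen flash-attention kernels expect (toy mask, illustrative only).

```python
import torch
import torch.nn.functional as F

# Toy illustration of _get_unpad_data above.
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                  # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(indices.tolist())      # [0, 1, 2, 4, 5] -> positions of real (non-pad) tokens
print(cu_seqlens.tolist())   # [0, 3, 5] -> per-sequence offsets into the packed batch
```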
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling['short_factor'] + self.long_factor = config.rope_scaling['long_factor'] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. 
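The 'su' and 'yarn' variants above differ only in the attention scaling applied on top of the rescaled frequencies; evaluating both formulas for a 4k to 128k extension makes the difference concrete (plain arithmetic, no repository code involved).

```python
import math

# The long-context attention scaling used by the 'su' and 'yarn' rotary
# embeddings above, evaluated for a 4096 -> 131072 extension.
orig, extended = 4096, 131072
scale = extended / orig                                        # 32.0
su_factor = math.sqrt(1 + math.log(scale) / math.log(orig))    # ~1.1902
yarn_factor = 0.1 * math.log(scale) + 1.0                      # ~1.3466
print(round(su_factor, 4), round(yarn_factor, 4))
```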
+ cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Phi3MLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Phi3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will ' + 'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` ' + 'when creating this class.' 
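`repeat_kv` above is what lets grouped-query attention reuse the ordinary multi-head matmul: each key/value head is tiled `num_heads // num_key_value_heads` times. A toy usage (arbitrary shapes) is:

```python
import torch

# Toy usage of repeat_kv above: 8 KV heads tiled up to 32 query heads (n_rep = 4).
def repeat_kv(hidden_states, n_rep):
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

k = torch.randn(2, 8, 128, 96)        # (batch, kv_heads, seq_len, head_dim)
print(repeat_kv(k, 32 // 8).shape)    # torch.Size([2, 32, 128, 96])
```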
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' + f' and `num_heads`: {self.num_heads}).' + ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = Phi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling['type'] + if scaling_type == 'su': + self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config) + elif scaling_type == 'yarn': + self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config) + else: + raise ValueError(f'Unknown RoPE scaling type {scaling_type}') + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.') + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.' 
+            )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
+                f' {attn_weights.size()}'
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
+                f' {attn_output.size()}'
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class Phi3FlashAttention2(Phi3Attention):
+    """
+    Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
+    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
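As an aside on the comment above, here is a small, self-contained illustration of the two mask alignments (toy sizes, not part of the patched file): two new query tokens attending over four keys, two of which come from the cache.

```python
import torch

# Toy sizes only: 2 new query tokens attending over 4 keys (2 of them cached).
q_len, k_len = 2, 4

# Bottom-right alignment (flash_attn >= 2.1): each new query sees all cached keys plus itself.
bottom_right = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool), diagonal=k_len - q_len)
# Top-left alignment (flash_attn < 2.1): query i only sees keys j <= i, silently dropping the cache.
top_left = torch.tril(torch.ones(q_len, k_len, dtype=torch.bool))

print(bottom_right.int())  # tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
print(top_left.int())      # tensor([[1, 0, 0, 0], [1, 1, 0, 0]])
```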
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Phi3FlashAttention2 attention does not support output_attentions + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library." + ) + raise ValueError('The current flash attention version does not support sliding window attention.') + + output_attentions = False + + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class ' + 'with a layer index.' + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
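The comment above motivates the `rotary_seq_len` computed just below: the cos/sin tables are sized by the largest position id, not only by the key count. A toy evaluation of that formula (numbers made up for illustration):

```python
import torch

# Made-up numbers: the largest position id (7) exceeds the key count currently tracked (3).
position_ids = torch.tensor([[5, 6, 7],
                             [0, 1, 2]])
kv_seq_len = 3

rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
print(rotary_seq_len)  # 8 -- the cos/sin tables must cover position 7
```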
+ rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, 'sliding_window', None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, 'sliding_window', None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got' + f' {past_key.shape}' + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.qkv_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, this might be related to' + f' the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.' 
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            dropout=attn_dropout,
+            use_sliding_windows=use_sliding_windows,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
+    ):
+        """
+        Calls the forward method of Flash Attention: if the input hidden states contain at least one padding token,
+        first unpad the input, then compute the attention scores, and finally pad the attention output back.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
+            use_sliding_windows (`bool`, *optional*):
+                Whether to activate sliding window attention.
+        """
+        if not self._flash_attn_uses_top_left_mask:
+            causal = self.is_causal
+        else:
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
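The TODO above refers to the `causal` flag set on the next line. A small illustration (toy sizes, independent of the patch) of why the guard drops causal masking when only a single query token is processed under the old top-left alignment:

```python
import torch

# Toy case: decoding one new token (q_len = 1) against 4 cached keys.
k_len = 4
top_left_causal = torch.tril(torch.ones(1, k_len, dtype=torch.bool))
print(top_left_causal.int())  # tensor([[1, 0, 0, 0]]) -- the query would only see the first key

# With a single trailing query, attending to all keys is already causal,
# so the code below simply sets causal=False when query_length == 1.
```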
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + # Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
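For orientation, here is a rough standalone sketch (not the library helper itself) of how values like the `indices_k`, `cu_seqlens_k` and `max_seqlen_in_batch_k` returned by `_get_unpad_data` above can be derived from a padding mask; the mask values are illustrative only.

```python
import torch
import torch.nn.functional as F

# Illustrative padding mask: batch of 2, max length 4, real lengths 3 and 2.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())

print(indices.tolist())     # [0, 1, 2, 4, 5] -- positions of real tokens in the flattened batch
print(cu_seqlens.tolist())  # [0, 3, 5]       -- cumulative sequence lengths for varlen attention
print(max_seqlen)           # 3
```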
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 +# TODO @Arthur no longer copied from LLama after static cache +class Phi3SdpaAttention(Phi3Attention): + """ + Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Phi3Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + 'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, ' + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with 
custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == 'cuda' and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHI3_ATTENTION_CLASSES = { + 'eager': Phi3Attention, + 'flash_attention_2': Phi3FlashAttention2, + 'sdpa': Phi3SdpaAttention, +} + + +class Phi3DecoderLayer(nn.Module): + def __init__(self, config: Phi3Config, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = Phi3MLP(config) + self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if 'padding_mask' in kwargs: + warnings.warn( + 'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`' + ) + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHI3_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Phi3Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.', + PHI3_START_DOCSTRING, +) +class Phi3PreTrainedModel(PreTrainedModel): + config_class = Phi3Config + base_model_prefix = 'model' + supports_gradient_checkpointing = True + _no_split_modules = ['Phi3DecoderLayer'] + _skip_keys_device_placement = 'past_key_values' + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + _version = '0.0.5' + + def __init__(self, config: Phi3Config): + if not has_flash_attn: + config._attn_implementation = 'eager' + print('Warning: Flash attention is not available, using eager attention instead.') + super().__init__(config) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHI3_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + 'The bare Phi-3 model outputting raw hidden-states without any specific head on top.', + PHI3_START_DOCSTRING, +) +class Phi3Model(Phi3PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Phi3DecoderLayer`] + + Args: + config: Phi3Config + """ + + def __init__(self, config: Phi3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.layers = nn.ModuleList( + [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + + self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time') + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError('You have to specify either input_ids or inputs_embeds') + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + ' this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to ' + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + if self._attn_implementation == 'flash_attention_2': + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=self.config.sliding_window, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Phi3ForCausalLM(Phi3PreTrainedModel): + _tied_weights_keys = ['lm_head.weight'] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 + def __init__(self, config): + super().__init__(config) + self.model = Phi3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, 
bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Phi3ForCausalLM + + >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! 
Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
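A couple of lines further down, `position_ids` for batched generation are rebuilt from the attention mask; a toy example of that cumulative-sum trick (values made up, left padding in the first row):

```python
import torch

# Left-padded batch: the first row has two padding positions in front.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```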
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get('position_ids', None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if (inputs_embeds is not None and past_key_values is None) or (inputs_embeds is not None and len(past_key_values) == 0): + model_inputs = {'inputs_embeds': inputs_embeds} + else: + model_inputs = {'input_ids': input_ids} + + model_inputs.update( + { + 'position_ids': position_ids, + 'past_key_values': past_key_values, + 'use_cache': kwargs.get('use_cache'), + 'attention_mask': attention_mask, + } + ) + return model_inputs + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The [`Phi3Model`] with a sequence classification head on top (linear layer). + + [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHI3_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs +class Phi3ForSequenceClassification(Phi3PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Phi3Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.') + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = 'regression' + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = 'single_label_classification' + else: + self.config.problem_type = 'multi_label_classification' + + if self.config.problem_type == 'regression': + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == 
'single_label_classification': + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == 'multi_label_classification': + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + PHI3_START_DOCSTRING, +) +# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs +class Phi3ForTokenClassification(Phi3PreTrainedModel): + def __init__(self, config: Phi3Config): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = Phi3Model(config) + if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) + ) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) diff --git a/internvl/patch/__init__.py b/internvl/patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..63f84fe933ba5e62ac4329f1b2db164d22b9df12 --- /dev/null +++ b/internvl/patch/__init__.py @@ -0,0 +1,34 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from .internlm2_packed_training_patch import replace_internlm2_attention_class +from .internvit_liger_monkey_patch import apply_liger_kernel_to_internvit +from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn +from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn +from .llama_packed_training_patch import replace_llama_attention_class +from .llama_rmsnorm_monkey_patch import \ + replace_llama_rmsnorm_with_fused_rmsnorm +from .pad_data_collator import (concat_pad_data_collator, + dpo_concat_pad_data_collator, + pad_data_collator) +from .phi3_packed_training_patch import replace_phi3_attention_class +from .qwen2_packed_training_patch import replace_qwen2_attention_class +from .train_dataloader_patch import replace_train_dataloader +from .train_sampler_patch import replace_train_sampler + +__all__ = ['replace_llama_attn_with_flash_attn', + 'replace_llama_rmsnorm_with_fused_rmsnorm', + 'replace_llama2_attn_with_flash_attn', + 'replace_train_sampler', + 'replace_train_dataloader', + 'replace_internlm2_attention_class', + 'replace_qwen2_attention_class', + 'replace_phi3_attention_class', + 'replace_llama_attention_class', + 'pad_data_collator', + 'dpo_concat_pad_data_collator', + 'concat_pad_data_collator', + 'apply_liger_kernel_to_internvit'] diff --git a/internvl/patch/__pycache__/__init__.cpython-39.pyc b/internvl/patch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..719c9fd406cc93c434d720d4f1457c0a89f50815 Binary files /dev/null and b/internvl/patch/__pycache__/__init__.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce626181c4880ac68e765c49c27d58cb7841c1f6 Binary files /dev/null and 
b/internvl/patch/__pycache__/internlm2_packed_training_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c66f6ccde4b246a8b791b1266b41c866860807f1 Binary files /dev/null and b/internvl/patch/__pycache__/internvit_liger_monkey_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e6d17d0ca475b1474d0e7b5f7d59b5003d8b407 Binary files /dev/null and b/internvl/patch/__pycache__/llama2_flash_attn_monkey_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d27cda75494b01780dc1caf565fb902500953806 Binary files /dev/null and b/internvl/patch/__pycache__/llama_flash_attn_monkey_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f3e97973d10abf40cbf8fff67ccdd7a5e029253 Binary files /dev/null and b/internvl/patch/__pycache__/llama_packed_training_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc b/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18aef4261936aa693ad1c3ceba294a36d13ef51e Binary files /dev/null and b/internvl/patch/__pycache__/llama_rmsnorm_monkey_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc b/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9ef9db1f0117bd90a1f145a5da27278b52ec236 Binary files /dev/null and b/internvl/patch/__pycache__/pad_data_collator.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d88d3bbbbdb30bbec35cf44fd5f9418b6dd0c37c Binary files /dev/null and b/internvl/patch/__pycache__/phi3_packed_training_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc b/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec200a6c9a2dba76069834e18d91195b20fa83cd Binary files /dev/null and b/internvl/patch/__pycache__/qwen2_packed_training_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc b/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffeb0a267ace08d24f1a6ca8fa9afcb3b26be50b Binary files /dev/null and b/internvl/patch/__pycache__/train_dataloader_patch.cpython-39.pyc differ diff --git a/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc b/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3f500d392e711a0c32b90d128d969b92eb801903 Binary files /dev/null and b/internvl/patch/__pycache__/train_sampler_patch.cpython-39.pyc differ diff --git a/internvl/patch/internlm2_packed_training_patch.py b/internvl/patch/internlm2_packed_training_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..c0805e63876f23645c1802b163bbdb844de5fe27 --- /dev/null +++ b/internvl/patch/internlm2_packed_training_patch.py @@ -0,0 +1,74 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch +from flash_attn.flash_attn_interface import flash_attn_varlen_func +from internvl.model.internlm2.modeling_internlm2 import ( + INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2, + apply_rotary_pos_emb) + + +# Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2 +class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2): + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 + query_states = query_states.squeeze(0) + key_states = key_states.squeeze(0) + value_states = value_states.squeeze(0) + cu_seqlens = attention_mask.squeeze(0) + + with torch.no_grad(): + max_seqlen = max([ + cu_seqlens[idx+1] - cu_seqlens[idx] + for idx in range(cu_seqlens.size(0) - 1) + ]).item() + + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + query_states = query_states.unsqueeze(0) + key_states = key_states.unsqueeze(0) + value_states = value_states.unsqueeze(0) + return attn_output + + +def replace_internlm2_attention_class(): + INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining + print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!') diff --git a/internvl/patch/internvit_liger_monkey_patch.py b/internvl/patch/internvit_liger_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..154aab1d1884968530329e08123276ba71faa07f --- /dev/null +++ b/internvl/patch/internvit_liger_monkey_patch.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +def apply_liger_kernel_to_internvit() -> None: + from internvl.model.internvl_chat import modeling_intern_vit + from liger_kernel.transformers.layer_norm import LigerLayerNorm + from liger_kernel.transformers.rms_norm import LigerRMSNorm + modeling_intern_vit.NORM2FN['rms_norm'] = LigerRMSNorm + modeling_intern_vit.NORM2FN['layer_norm'] = LigerLayerNorm + print('Liger kernel applied to InternViT') diff --git a/internvl/patch/llama2_flash_attn_monkey_patch.py b/internvl/patch/llama2_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..01091ef2a0c1197b4638a4b8d95214670ac94039 --- /dev/null +++ b/internvl/patch/llama2_flash_attn_monkey_patch.py @@ -0,0 +1,238 @@ +""" +This file is copied from: https://github.com/lm-sys/FastChat +""" + +import warnings +from typing import Optional, Tuple + +import torch +from flash_attn import __version__ as flash_attn_version +from flash_attn.bert_padding import pad_input, unpad_input +from flash_attn.flash_attn_interface import (flash_attn_func, + flash_attn_varlen_kvpacked_func) +from transformers.models.llama.modeling_llama import (LlamaAttention, + LlamaModel, rotate_half) + + +def apply_rotary_pos_emb(q, k, cos_sin, position_ids): + gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] + gather_indices = gather_indices.repeat( + 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] + ) + bsz = gather_indices.shape[0] + cos, sin = ( + torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) + for x in cos_sin + ) + q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) + return q, k + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + 
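To make the packed-training convention above concrete: `attention_mask` is reinterpreted as `cu_seqlens`, the cumulative token boundaries of the samples packed into the single batch row. A minimal, self-contained sketch of that bookkeeping (toy lengths and illustrative names only; this is not part of the patch):

```python
import torch

# Three toy sequences of 5, 3 and 7 tokens packed into one row of total length 15.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens has shape (num_seqs + 1,): boundaries [0, 5, 8, 15].
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seq_lens.cumsum(0).to(torch.int32)])

# max_seqlen is recovered exactly as in the patch: the largest gap between
# consecutive cu_seqlens entries.
max_seqlen = max(
    cu_seqlens[idx + 1] - cu_seqlens[idx] for idx in range(cu_seqlens.size(0) - 1)
).item()

print(cu_seqlens.tolist(), max_seqlen)  # [0, 5, 8, 15] 7
```

This `cu_seqlens`/`max_seqlen` pair is exactly what `flash_attn_varlen_func` consumes in each of the packed-training patches in this diff.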
use_cache: bool = False, + padding_mask: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + warnings.warn( + 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.' + ) + + bsz, q_len, _ = hidden_states.size() + kv_heads = getattr(self, 'num_key_value_heads', self.num_heads) + + q, k, v = ( + op(hidden_states).view(bsz, q_len, nh, self.head_dim) + for op, nh in ( + (self.q_proj, self.num_heads), + (self.k_proj, kv_heads), + (self.v_proj, kv_heads), + ) + ) + # shape: (b, s, num_heads, head_dim) + + kv_seq_len = k.shape[1] + past_kv_len = 0 + if past_key_value is not None: + past_kv_len = past_key_value[0].shape[2] + kv_seq_len += past_kv_len + + cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) + + if past_key_value is not None: + assert ( + flash_attn_version >= '2.1.0' + ), 'past_key_value support requires flash-attn >= 2.1.0' + # reuse k, v + k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) + v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) + + past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None + + if attention_mask is None: + output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( + bsz, q_len, -1 + ) + else: + q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) + # We can skip concat and call unpad twice but seems better to call unpad only once. + kv, _, cu_k_lens, max_k = unpad_input( + torch.stack((k, v), dim=2), attention_mask + ) + output_unpad = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_q_lens, + cu_k_lens, + max_s, + max_k, + 0.0, + softmax_scale=None, + causal=True, + ) + output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) + output = pad_input(output_unpad, indices, bsz, q_len) + + return self.o_proj(output), None, past_key_value + + +# Disable the transformation of the attention mask in LlamaModel as flash attention +# takes a boolean key_padding_mask. Fills in the past kv length for use in forward. +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + if past_key_values_length > 0 and attention_mask is not None: + attention_mask = torch.cat( + ( + torch.full( + (input_shape[0], past_key_values_length), + True, + dtype=attention_mask.dtype, + device=attention_mask.device, + ), + attention_mask, + ), + dim=-1, + ) + + if attention_mask is not None and torch.all(attention_mask): + return None # This uses the faster call when training with full samples + + return attention_mask + + +def replace_llama2_attn_with_flash_attn(): + cuda_major, cuda_minor = torch.cuda.get_device_capability() + if cuda_major < 8: + warnings.warn( + 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' 
+ 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593' + ) + + LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask + LlamaAttention.forward = forward + + +def test(): + from fastchat.train.llama_flash_attn_monkey_patch import \ + forward as fastchat_forward + from transformers.models.llama.configuration_llama import LlamaConfig + + config = LlamaConfig( + hidden_size=1024, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=8, + max_position_embeddings=16, + ) + device = torch.device('cuda') + model = LlamaModel(config) + attn = LlamaAttention(config).to(device).half() + bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings + position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( + -1, seqlen + ) + + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + for i in range(4): + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + if i: + mask[0, -i:] = False + mask[1, :i] = False + + lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) + ref, _, _ = attn.forward( + hidden, attention_mask=lmask, position_ids=position_ids + ) + + fast, _, _ = fastchat_forward( + attn, hidden, attention_mask=mask, position_ids=position_ids + ) + + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + test, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + + print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}') + print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}') + print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}') + print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}') + print(f'allclose(fast, test) = {torch.allclose(fast, test)}') + + with torch.no_grad(): + # Also check that past_kv is handled properly + hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) + part_len = seqlen // 4 + assert part_len * 4 == seqlen + mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) + mask[0, -2:] = False + lmask = _prepare_decoder_attention_mask( + model, mask, hidden.shape[:2], hidden, 0 + ) + oneshot, _, _ = forward( + attn, hidden, attention_mask=lmask, position_ids=position_ids + ) + parts = [] + past_kv, past_kv_len = None, 0 + for i in range(4): + start = part_len * i + end = start + part_len + hidden_part = hidden[:, start:end, ...] 
+ lmask = _prepare_decoder_attention_mask( + model, + mask[:, start:end], + hidden_part.shape[:2], + hidden_part, + past_kv_len, + ) + part, _, past_kv = forward( + attn, + hidden_part.clone(), + attention_mask=lmask, + position_ids=position_ids[:, start:end], + past_key_value=past_kv, + use_cache=True, + ) + parts.append(part) + past_kv_len = past_kv[0].shape[2] + + print( + f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}' + ) + print( + f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}' + ) + + +if __name__ == '__main__': + test() diff --git a/internvl/patch/llama_flash_attn_monkey_patch.py b/internvl/patch/llama_flash_attn_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..d8b01d5811ad7860cd00786d83ecc5aafaf82aa4 --- /dev/null +++ b/internvl/patch/llama_flash_attn_monkey_patch.py @@ -0,0 +1,222 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import math +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +import transformers +from torch import nn +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel + + attention_mask: [bsz, q_len] + """ + from einops import rearrange + try: # v1 + from flash_attn.flash_attn_interface import \ + flash_attn_unpadded_qkvpacked_func + except: # v2 + from flash_attn.flash_attn_interface import \ + flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func + from flash_attn.bert_padding import pad_input, unpad_input + + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + # [bsz, q_len, nh, hd] + # [bsz, nh, q_len, hd] + + kv_seq_len = key_states.shape[-2] + assert past_key_value is None, 'past_key_value is not supported' + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + # [bsz, nh, t, hd] + assert not output_attentions, 'output_attentions is not supported' + assert not use_cache, 'use_cache is not supported' + + # Flash attention codes from + # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py + + # transform the data into the format required by flash attention + qkv = torch.stack( + [query_states, key_states, value_states], dim=2 + ) # [bsz, nh, 3, q_len, hd] + qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] + # We have disabled _prepare_decoder_attention_mask in LlamaModel + # the attention_mask should be the same as the key_padding_mask + key_padding_mask = attention_mask + + if key_padding_mask is None: + qkv = 
rearrange(qkv, 'b s ... -> (b s) ...') + max_s = q_len + cu_q_lens = torch.arange( + 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device + ) + output = flash_attn_unpadded_qkvpacked_func( + qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange(output, '(b s) ... -> b s ...', b=bsz) + else: + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) + x_unpad = rearrange( + x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads + ) + output_unpad = flash_attn_unpadded_qkvpacked_func( + x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True + ) + output = rearrange( + pad_input( + rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len + ), + 'b s (h d) -> b s h d', + h=nheads, + ) + return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None + + +# Disable the transformation of the attention mask in LlamaModel as the flash attention +# requires the attention mask to be the same as the key_padding_mask +def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length +): + # [bsz, seq_len] + return attention_mask + + +def forward_2( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = ( + self.q_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + key_states = ( + self.k_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + value_states = ( + self.v_proj(hidden_states) + .view(bsz, q_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + + assert not output_attentions, 'output_attentions is not supported' + assert not use_cache, 'use_cache is not supported' + assert past_key_value is None, 'past_key_value is not supported' + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + if self.training: + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, dropout_p=0.0, is_causal=True + ) + attn_weights = None + else: + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is' + f' {attn_weights.size()}' + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' + ) + attn_weights = attn_weights + attention_mask + attn_weights = 
torch.max( + attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' + f' {attn_output.size()}' + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def replace_llama_attn_with_flash_attn(): + if hasattr(F, 'scaled_dot_product_attention'): + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2 + else: + transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( + _prepare_decoder_attention_mask + ) + transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/internvl/patch/llama_packed_training_patch.py b/internvl/patch/llama_packed_training_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..6c915554919ba72f53544744c53f59d9e131c41c --- /dev/null +++ b/internvl/patch/llama_packed_training_patch.py @@ -0,0 +1,106 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch +from flash_attn.flash_attn_interface import flash_attn_varlen_func +from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES, + LlamaFlashAttention2) + + +# Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2 +class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2): + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. 
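The `replace_llama_attn_with_flash_attn` helper above prefers `forward_2` whenever `torch.nn.functional.scaled_dot_product_attention` is available; on the training path that call builds the causal mask internally. A small, self-contained illustration of that call with toy shapes (nothing here is taken from the patch):

```python
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(bsz, n_heads, seq_len, head_dim)
k = torch.randn(bsz, n_heads, seq_len, head_dim)
v = torch.randn(bsz, n_heads, seq_len, head_dim)

# is_causal=True applies the lower-triangular mask internally, so no explicit
# attention_mask tensor is needed on this path.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```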
+ """ + assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 + query_states = query_states.squeeze(0) + key_states = key_states.squeeze(0) + value_states = value_states.squeeze(0) + cu_seqlens = attention_mask.squeeze(0) + + with torch.no_grad(): + max_seqlen = max([ + cu_seqlens[idx+1] - cu_seqlens[idx] + for idx in range(cu_seqlens.size(0) - 1) + ]).item() + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Decide whether to use SWA or not by layer index. + if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: + use_sliding_windows = False + + if not use_sliding_windows: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + query_states = query_states.unsqueeze(0) + key_states = key_states.unsqueeze(0) + value_states = value_states.unsqueeze(0) + return attn_output + + +def replace_llama_attention_class(): + LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining + print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!') diff --git a/internvl/patch/llama_rmsnorm_monkey_patch.py b/internvl/patch/llama_rmsnorm_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..68efb44468bef51066d6633732b4f15de6daa04c --- /dev/null +++ b/internvl/patch/llama_rmsnorm_monkey_patch.py @@ -0,0 +1,23 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import transformers + + +def replace_llama_rmsnorm_with_fused_rmsnorm(): + try: + from functools import partial + + from apex.normalization import FusedRMSNorm + LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm') + except ImportError: + # using the normal LlamaRMSNorm + pass + except Exception: + print('discovered apex but it failed to load, falling back to LlamaRMSNorm') + pass diff --git a/internvl/patch/pad_data_collator.py b/internvl/patch/pad_data_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..f803ed5f76b5a2f621dcc5fc1242438872558c5c --- /dev/null +++ b/internvl/patch/pad_data_collator.py @@ -0,0 +1,206 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import numpy as np +import torch + +IGNORE_INDEX = -100 + + +def pad_data_collator(features, pad_id=0): + + first = features[0] + 
batch = {} + + batch_lens = [feat['input_ids'].shape for feat in features] + max_item_length = max(batch_lens)[0] + for idx in range(len(features)): + feat = features[idx] + temp_input_ids = torch.LongTensor([pad_id] * max_item_length) + temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] + feat['input_ids'] = temp_input_ids + temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) + temp_labels[:feat['labels'].shape[0]] = feat['labels'] + feat['labels'] = temp_labels + feat['attention_mask'] = feat['input_ids'].ne(pad_id) + + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + if 'label' in first and first['label'] is not None: + label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] + dtype = torch.long if isinstance(label, int) else torch.float + batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) + elif 'label_ids' in first and first['label_ids'] is not None: + if isinstance(first['label_ids'], torch.Tensor): + batch['labels'] = torch.stack([f['label_ids'] for f in features]) + else: + dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float + batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which key/values are not None for this model. + for k, v in first.items(): + if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + return batch + + +# def concat_pad_data_collator(features, max_item_length=None, pad_id=0): + +# first = features[0] +# batch = {} + +# batch_lens = [feat['input_ids'].shape for feat in features] +# max_item_length = max_item_length or max(batch_lens)[0] +# for idx in range(len(features)): +# feat = features[idx] +# temp_input_ids = torch.LongTensor([pad_id] * max_item_length) +# temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] +# feat['input_ids'] = temp_input_ids +# temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) +# temp_labels[:feat['labels'].shape[0]] = feat['labels'] +# feat['labels'] = temp_labels +# feat['attention_mask'] = feat['input_ids'].ne(pad_id) + +# if 'position_ids' in feat: +# temp_position_ids = torch.LongTensor([pad_id] * max_item_length) +# temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] +# feat['position_ids'] = temp_position_ids + +# if 'loss_weight' in feat: +# temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length) +# temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight'] +# feat['loss_weight'] = temp_loss_weight + +# # Special handling for labels. +# # Ensure that tensor is created with the correct type +# # (it should be automatically the case, but let's make sure of it.) 
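As a quick, self-contained illustration of what the padding loop in `pad_data_collator` above does to a ragged batch (toy ids only; nothing here comes from the repository):

```python
import torch

IGNORE_INDEX = -100
pad_id = 0

features = [
    {'input_ids': torch.tensor([5, 6, 7]), 'labels': torch.tensor([5, 6, 7])},
    {'input_ids': torch.tensor([8, 9]),    'labels': torch.tensor([8, 9])},
]

max_len = max(f['input_ids'].shape[0] for f in features)
for feat in features:
    ids = torch.full((max_len,), pad_id, dtype=torch.long)
    ids[:feat['input_ids'].shape[0]] = feat['input_ids']
    labels = torch.full((max_len,), IGNORE_INDEX, dtype=torch.long)
    labels[:feat['labels'].shape[0]] = feat['labels']
    feat['input_ids'], feat['labels'] = ids, labels
    feat['attention_mask'] = ids.ne(pad_id)

print(features[1]['input_ids'].tolist())       # [8, 9, 0]
print(features[1]['labels'].tolist())          # [8, 9, -100]
print(features[1]['attention_mask'].tolist())  # [True, True, False]
```

Pad positions are masked out twice: `attention_mask` hides them from attention and `IGNORE_INDEX` hides them from the loss.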
+# if 'label' in first and first['label'] is not None: +# label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] +# dtype = torch.long if isinstance(label, int) else torch.float +# batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) +# elif 'label_ids' in first and first['label_ids'] is not None: +# if isinstance(first['label_ids'], torch.Tensor): +# batch['labels'] = torch.stack([f['label_ids'] for f in features]) +# else: +# dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float +# batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) + +# # Handling of all other possible keys. +# # Again, we will use the first element to figure out which key/values are not None for this model. +# for k, v in first.items(): +# if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ +# v is not None and not isinstance(v, str): +# if isinstance(v, torch.Tensor): +# batch[k] = torch.stack([f[k] for f in features]) +# elif isinstance(v, np.ndarray): +# batch[k] = torch.tensor(np.stack([f[k] for f in features])) +# else: +# batch[k] = torch.tensor([f[k] for f in features]) +# if k in ('pixel_values', 'image_flags'): +# if isinstance(v, torch.Tensor): +# batch[k] = torch.concat([f[k] for f in features]) +# elif isinstance(v, np.ndarray): +# batch[k] = torch.concat(np.stack([f[k] for f in features])) +# else: +# batch[k] = torch.concat([f[k] for f in features]) +# return batch + +def concat_pad_data_collator(features, pad_id=0): + # import pdb + # pdb.set_trace() + first = features[0] + batch = {} + batch_lens = [len(feat['input_ids']) for feat in features] + max_item_length = max(batch_lens) + for idx in range(len(features)): + feat = features[idx] + temp_input_ids = torch.LongTensor([pad_id] * max_item_length) + temp_input_ids[:len(feat['input_ids'])] = torch.Tensor(feat['input_ids']) + feat['input_ids'] = temp_input_ids + # temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) + # temp_labels[:feat['labels'].shape[0]] = feat['labels'] + # feat['labels'] = temp_labels + feat['attention_mask'] = feat['input_ids'].ne(pad_id) + + # Special handling for labels. + # Ensure that tensor is created with the correct type + # (it should be automatically the case, but let's make sure of it.) + if 'label' in first and first['label'] is not None: + label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] + dtype = torch.long if isinstance(label, int) else torch.float + batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) + elif 'label_ids' in first and first['label_ids'] is not None: + if isinstance(first['label_ids'], torch.Tensor): + batch['labels'] = torch.stack([f['label_ids'] for f in features]) + else: + dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float + batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which key/values are not None for this model. 
+ for k, v in first.items(): + if k not in ('label', 'label_ids', 'pixel_values', 'image_flags', 'mask_values', 'masks_flags') and \ + v is not None and not isinstance(v, str): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + if k in ('pixel_values', 'image_flags', 'mask_values', 'masks_flags'): + if isinstance(v, torch.Tensor): + batch[k] = torch.concat([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.concat(np.stack([f[k] for f in features])) + else: + batch[k] = torch.concat([f[k] for f in features]) + return batch + + +def dpo_concat_pad_data_collator(features, pad_id=0): + + first = features[0] + batch = {} + + for prefix in ['chosen_', 'rejected_']: + batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features] + max_item_length = max(batch_lens) + for idx in range(len(features)): + feat = features[idx] + temp_input_ids = torch.LongTensor([pad_id] * max_item_length) + temp_input_ids[:feat[f'{prefix}input_ids'].shape[0]] = feat[f'{prefix}input_ids'] + feat[f'{prefix}input_ids'] = temp_input_ids + temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) + temp_labels[:feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels'] + feat[f'{prefix}labels'] = temp_labels + feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id) + + # Handling of all other possible keys. + # Again, we will use the first element to figure out which key/values are not None for this model. + for k, v in first.items(): + if k not in ('pixel_values', 'image_flags') and \ + v is not None and not isinstance(v, str): + if isinstance(v, torch.Tensor): + batch[k] = torch.stack([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.tensor(np.stack([f[k] for f in features])) + else: + batch[k] = torch.tensor([f[k] for f in features]) + if k in ('pixel_values', 'image_flags'): + if isinstance(v, torch.Tensor): + batch[k] = torch.concat([f[k] for f in features]) + elif isinstance(v, np.ndarray): + batch[k] = torch.concat(np.stack([f[k] for f in features])) + else: + batch[k] = torch.concat([f[k] for f in features]) + return batch diff --git a/internvl/patch/phi3_packed_training_patch.py b/internvl/patch/phi3_packed_training_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdb60e85b37006a288e5cc205f8d75774acb092 --- /dev/null +++ b/internvl/patch/phi3_packed_training_patch.py @@ -0,0 +1,105 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch +from flash_attn.flash_attn_interface import flash_attn_varlen_func +from internvl.model.phi3.modeling_phi3 import (PHI3_ATTENTION_CLASSES, + Phi3FlashAttention2) + + +class Phi3FlashAttention2ForPackedTraining(Phi3FlashAttention2): + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
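One detail of the collators above worth spelling out: `pixel_values` and `image_flags` are concatenated along the first dimension rather than stacked, because dynamic tiling means each sample can contribute a different number of image tiles. A minimal, self-contained sketch with toy tile counts (not from the repository):

```python
import torch

pixel_values_a = torch.randn(3, 3, 224, 224)  # sample A: 3 tiles
pixel_values_b = torch.randn(1, 3, 224, 224)  # sample B: 1 tile

try:
    torch.stack([pixel_values_a, pixel_values_b])   # stacking requires equal tile counts
except RuntimeError as e:
    print('stack fails:', type(e).__name__)

pixel_values = torch.concat([pixel_values_a, pixel_values_b])
print(pixel_values.shape)  # torch.Size([4, 3, 224, 224])
```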
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 + query_states = query_states.squeeze(0) + key_states = key_states.squeeze(0) + value_states = value_states.squeeze(0) + cu_seqlens = attention_mask.squeeze(0) + + with torch.no_grad(): + max_seqlen = max([ + cu_seqlens[idx+1] - cu_seqlens[idx] + for idx in range(cu_seqlens.size(0) - 1) + ]).item() + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Decide whether to use SWA or not by layer index. + if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: + use_sliding_windows = False + + if not use_sliding_windows: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + query_states = query_states.unsqueeze(0) + key_states = key_states.unsqueeze(0) + value_states = value_states.unsqueeze(0) + return attn_output + + +def replace_phi3_attention_class(): + PHI3_ATTENTION_CLASSES['flash_attention_2'] = Phi3FlashAttention2ForPackedTraining + print('Replace PHI3_ATTENTION_CLASSES to support packed training!!') diff --git a/internvl/patch/qwen2_packed_training_patch.py b/internvl/patch/qwen2_packed_training_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..0b55888f0b2da0571d6c47a9cc7c3906fa918281 --- /dev/null +++ b/internvl/patch/qwen2_packed_training_patch.py @@ -0,0 +1,106 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch +from flash_attn.flash_attn_interface import flash_attn_varlen_func +from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES, + Qwen2FlashAttention2) + + +# Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 +class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2): + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + 
query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 + query_states = query_states.squeeze(0) + key_states = key_states.squeeze(0) + value_states = value_states.squeeze(0) + cu_seqlens = attention_mask.squeeze(0) + + with torch.no_grad(): + max_seqlen = max([ + cu_seqlens[idx+1] - cu_seqlens[idx] + for idx in range(cu_seqlens.size(0) - 1) + ]).item() + + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Decide whether to use SWA or not by layer index. 
+ if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: + use_sliding_windows = False + + if not use_sliding_windows: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + query_states = query_states.unsqueeze(0) + key_states = key_states.unsqueeze(0) + value_states = value_states.unsqueeze(0) + return attn_output + + +def replace_qwen2_attention_class(): + QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining + print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!') diff --git a/internvl/patch/train_dataloader_patch.py b/internvl/patch/train_dataloader_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..eade21e51ca8011f379eca5c4a20f41b31e4db58 --- /dev/null +++ b/internvl/patch/train_dataloader_patch.py @@ -0,0 +1,53 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import datasets +import torch +import transformers +from torch.utils.data import DataLoader +from transformers.trainer import is_datasets_available, seed_worker + + +def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. 
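The `replace_*_attention_class` helpers in these patches all rely on the same mechanism: Transformers models look up their attention implementation by name in a module-level dictionary, so rebinding one dict entry re-routes every layer to the packed-training subclass. A toy, self-contained sketch of that pattern (illustrative classes only, not the real registry):

```python
class FlashAttention2:
    def forward(self):
        return 'padded flash attention'

class FlashAttention2ForPackedTraining(FlashAttention2):
    def forward(self):
        return 'varlen flash attention over packed sequences'

# Stand-in for QWEN2_ATTENTION_CLASSES / LLAMA_ATTENTION_CLASSES / PHI3_ATTENTION_CLASSES.
ATTENTION_CLASSES = {'flash_attention_2': FlashAttention2}

def replace_attention_class():
    ATTENTION_CLASSES['flash_attention_2'] = FlashAttention2ForPackedTraining

replace_attention_class()
print(ATTENTION_CLASSES['flash_attention_2']().forward())
# varlen flash attention over packed sequences
```

The swap only affects models constructed after it runs, so these helpers are meant to be called before the model is built.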
+ """ + if self.train_dataset is None: + raise ValueError('Trainer: training requires a train_dataset.') + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description='training') + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description='training') + + dataloader_params = { + 'batch_size': self._train_batch_size, + 'collate_fn': data_collator, + 'num_workers': self.args.dataloader_num_workers, + 'pin_memory': self.args.dataloader_pin_memory, + 'persistent_workers': self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params['sampler'] = self._get_train_sampler() + dataloader_params['drop_last'] = self.args.dataloader_drop_last + dataloader_params['worker_init_fn'] = seed_worker + + if self.args.use_packed_ds: + return DataLoader(train_dataset, **dataloader_params) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + +def replace_train_dataloader(): + transformers.Trainer.get_train_dataloader = get_train_dataloader + # print('Replace train dataloader!!') diff --git a/internvl/patch/train_sampler_patch.py b/internvl/patch/train_sampler_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..35d616e055052108c246f09ebe24f692bfb2eb8d --- /dev/null +++ b/internvl/patch/train_sampler_patch.py @@ -0,0 +1,125 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from typing import List, Optional + +import torch +import transformers +from torch.utils.data import Dataset, Sampler +from transformers.tokenization_utils_base import BatchEncoding +from transformers.trainer import (LengthGroupedSampler, RandomSampler, + has_length) +from transformers.trainer_pt_utils import logger + + +# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38 +def split_to_even_chunks(indices, lengths, num_chunks): + """ + Split a list of indices into `chunks` chunks of roughly equal lengths. + """ + + if len(indices) % num_chunks != 0: + return [indices[i::num_chunks] for i in range(num_chunks)] + + num_indices_per_chunk = len(indices) // num_chunks + + chunks = [[] for _ in range(num_chunks)] + chunks_lengths = [0 for _ in range(num_chunks)] + for index in indices: + shortest_chunk = chunks_lengths.index(min(chunks_lengths)) + chunks[shortest_chunk].append(index) + chunks_lengths[shortest_chunk] += lengths[index] + if len(chunks[shortest_chunk]) == num_indices_per_chunk: + chunks_lengths[shortest_chunk] = float('inf') + + return chunks + + +# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88 +def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): + # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
+ indices = torch.randperm(len(lengths), generator=generator) + megabatch_size = world_size * batch_size + megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] + megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] + megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] + + return [i for megabatch in megabatches for batch in megabatch for i in batch] + + +# modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99 +class LengthGroupedSampler(Sampler): + r""" + Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while + keeping a bit of randomness. + """ + + def __init__( + self, + batch_size: int, + world_size: int, + dataset: Optional[Dataset] = None, + lengths: Optional[List[int]] = None, + model_input_name: Optional[str] = None, + generator=None, + ): + if dataset is None and lengths is None: + raise ValueError('One of dataset and lengths must be provided.') + + self.batch_size = batch_size + if lengths is None: + model_input_name = model_input_name if model_input_name is not None else 'input_ids' + if ( + not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) + or model_input_name not in dataset[0] + ): + raise ValueError( + 'Can only automatically infer lengths for datasets whose items are dictionaries with an ' + f"'{model_input_name}' key." + ) + lengths = [len(feature[model_input_name]) for feature in dataset] + elif isinstance(lengths, torch.Tensor): + logger.info( + 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...' + ) + lengths = lengths.tolist() + self.world_size = world_size + self.lengths = lengths + self.generator = generator + + def __len__(self): + return len(self.lengths) + + def __iter__(self): + indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) + return iter(indices) + + +# patch trainer +def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + # Build the sampler. 
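A compact, self-contained walk-through of the grouping logic above, using toy lengths; the real sampler additionally splits each sorted megabatch into per-rank chunks with `split_to_even_chunks`:

```python
import torch

lengths = [3, 15, 7, 9, 2, 12, 5, 8]      # toy per-sample token counts
batch_size, world_size = 2, 2             # a megabatch covers world_size * batch_size samples
g = torch.Generator().manual_seed(0)

indices = torch.randperm(len(lengths), generator=g).tolist()
megabatch_size = world_size * batch_size
megabatches = [indices[i:i + megabatch_size]
               for i in range(0, len(indices), megabatch_size)]
# Sort each megabatch longest-first so samples of similar length end up in the same step.
megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]
print(megabatches)
```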
+ if self.args.group_by_length: + lengths = [] + for dataset in self.train_dataset.datasets: + lengths = lengths + dataset.length + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + return LengthGroupedSampler( + self.args.train_batch_size, + world_size=self.args.world_size * self.args.gradient_accumulation_steps, + # self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) + else: + return RandomSampler(self.train_dataset) + + +def replace_train_sampler(): + transformers.Trainer._get_train_sampler = _get_train_sampler + # print('Replace train sampler!!') diff --git a/internvl/train/__init__.py b/internvl/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e207feb58fdcb609895b072fae31371385c522a7 --- /dev/null +++ b/internvl/train/__init__.py @@ -0,0 +1,5 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- diff --git a/internvl/train/__pycache__/__init__.cpython-310.pyc b/internvl/train/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fdfa4701c7b7fb6e7225b3b00f4652e12b7c3aa Binary files /dev/null and b/internvl/train/__pycache__/__init__.cpython-310.pyc differ diff --git a/internvl/train/__pycache__/__init__.cpython-39.pyc b/internvl/train/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98184b3eb624f02fd106965a292e21eac278ed98 Binary files /dev/null and b/internvl/train/__pycache__/__init__.cpython-39.pyc differ diff --git a/internvl/train/__pycache__/constants.cpython-310.pyc b/internvl/train/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..939dae928a517aa0a11a187ec900646395da3e6a Binary files /dev/null and b/internvl/train/__pycache__/constants.cpython-310.pyc differ diff --git a/internvl/train/__pycache__/constants.cpython-39.pyc b/internvl/train/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6f87104aa41c876676d2a86ba0aaf5d9330c58f Binary files /dev/null and b/internvl/train/__pycache__/constants.cpython-39.pyc differ diff --git a/internvl/train/__pycache__/dataset.cpython-310.pyc b/internvl/train/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80205294b6c2eb6d5525bb32dd36b6b0d0a3331 Binary files /dev/null and b/internvl/train/__pycache__/dataset.cpython-310.pyc differ diff --git a/internvl/train/__pycache__/dataset.cpython-39.pyc b/internvl/train/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ad06cc1e5372a37aa1c6af4a8101f435a16d53 Binary files /dev/null and b/internvl/train/__pycache__/dataset.cpython-39.pyc differ diff --git a/internvl/train/__pycache__/dataset_packed.cpython-39.pyc b/internvl/train/__pycache__/dataset_packed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24f3f57af4f1b31784a981640cf2a31b82a3d1cd Binary files /dev/null and b/internvl/train/__pycache__/dataset_packed.cpython-39.pyc differ diff --git a/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc b/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc new file 
mode 100644 index 0000000000000000000000000000000000000000..a5150c2b2fb90e37f173558f2e14f16bb9923d82 Binary files /dev/null and b/internvl/train/__pycache__/trainer_monkey_patch.cpython-39.pyc differ diff --git a/internvl/train/constants.py b/internvl/train/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..86f6c967b1b5071b140b2e16bbd4d352d9373204 --- /dev/null +++ b/internvl/train/constants.py @@ -0,0 +1,21 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +IMG_CONTEXT_TOKEN = '' +IMG_START_TOKEN = '' +IMG_END_TOKEN = '' +QUAD_START_TOKEN = '' +QUAD_END_TOKEN = '' +REF_START_TOKEN = '' +REF_END_TOKEN = '' +BOX_START_TOKEN = '' +BOX_END_TOKEN = '' +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) +CLIP_STD = (0.2686295, 0.2613025, 0.2757711) +SIGLIP_MEAN = (0.5, 0.5, 0.5) +SIGLIP_STD = (0.5, 0.5, 0.5) diff --git a/internvl/train/dataset.py b/internvl/train/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..56cfa77518f9fc1ed96de9300d98c932f1a55ffa --- /dev/null +++ b/internvl/train/dataset.py @@ -0,0 +1,921 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import io +import matplotlib.pyplot as plt +from transformers.trainer_pt_utils import LabelSmoother + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +import os +import random +import re +from collections import Counter +from typing import Dict + +import cv2 +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as T +import transformers +from decord import VideoReader +from internvl.conversation import get_conv_template +from PIL import Image +from torch.utils.data import ConcatDataset, WeightedRandomSampler +from torchvision.transforms.functional import InterpolationMode + +from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD, + IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, + SIGLIP_MEAN, SIGLIP_STD) + +try: + from petrel_client.client import Client + from petrel_client.common.config import Config +except ImportError as E: + print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.') +import sys + + +def calculate_ngram_repetition(text, n): + words = text.split() + ngrams = [tuple(words[i:i+n]) for i in range(len(words)-n+1)] + ngram_counts = Counter(ngrams) + total_ngrams = len(ngrams) + repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1) + return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0 + + +def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10): + for conversation in conversations: + if conversation['from'] == 'gpt': + model_answer = conversation['value'] + repeat_ratio = calculate_ngram_repetition(model_answer, ngram) + if repeat_ratio > repeat_threshold: + raise Exception + + +def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1): + if sample in ['rand', 'middle']: # uniform sampling + acc_samples = min(num_frames, vlen) + # split the video into `acc_samples` intervals, and sample from each interval. 
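To see what `calculate_ngram_repetition` measures, here is a self-contained restatement of its core with a hand-checked example; `check_conversations_repetition` then rejects model answers whose ratio exceeds the 0.4 threshold:

```python
from collections import Counter

def ngram_repetition(text, n):
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    counts = Counter(ngrams)
    repeated = sum(1 for c in counts.values() if c > 1)
    return repeated / len(ngrams) if ngrams else 0

# 'a b a b a b' has 5 bigrams; both distinct bigrams repeat, so the ratio is 2 / 5.
print(ngram_repetition('a b a b a b', 2))  # 0.4
```

The dataset filter applies the same ratio to 10-grams of model answers, which catches the long verbatim loops typical of degenerate generations.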
+ intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int) + ranges = [] + for idx, interv in enumerate(intervals[:-1]): + ranges.append((interv, intervals[idx + 1] - 1)) + if sample == 'rand': + try: + frame_indices = [random.choice(range(x[0], x[1])) for x in ranges] + except: + frame_indices = np.random.permutation(vlen)[:acc_samples] + frame_indices.sort() + frame_indices = list(frame_indices) + elif fix_start is not None: + frame_indices = [x[0] + fix_start for x in ranges] + elif sample == 'middle': + frame_indices = [(x[0] + x[1]) // 2 for x in ranges] + else: + raise NotImplementedError + + if len(frame_indices) < num_frames: # padded with last frame + padded_frame_indices = [frame_indices[-1]] * num_frames + padded_frame_indices[:len(frame_indices)] = frame_indices + frame_indices = padded_frame_indices + elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps + output_fps = float(sample[3:]) + duration = float(vlen) / input_fps + delta = 1 / output_fps # gap between frames, this is also the clip length each frame represents + frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta) + frame_indices = np.around(frame_seconds * input_fps).astype(int) + frame_indices = [e for e in frame_indices if e < vlen] + if max_num_frames > 0 and len(frame_indices) > max_num_frames: + frame_indices = frame_indices[:max_num_frames] + # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames) + else: + raise ValueError + return frame_indices + + +def read_frames_gif( + video_path, num_frames, sample='rand', fix_start=None, + client=None, min_num_frames=4 +): + if 's3://' in video_path: + video_bytes = client.get(video_path) + gif = imageio.get_reader(io.BytesIO(video_bytes)) + else: + gif = imageio.get_reader(video_path) + vlen = len(gif) + + t_num_frames = np.random.randint(min_num_frames, num_frames + 1) + frame_indices = get_frame_indices( + t_num_frames, vlen, sample=sample, fix_start=fix_start + ) + frames = [] + for index, frame in enumerate(gif): + if index in frame_indices: + frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8) + frame = Image.fromarray(frame) + frames.append(frame) + return frames + + +def read_frames_decord( + video_path, num_frames, sample='rand', fix_start=None, + client=None, clip=None, min_num_frames=4 +): + if 's3://' in video_path: + video_bytes = client.get(video_path) + video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1) + else: + video_reader = VideoReader(video_path, num_threads=1) + vlen = len(video_reader) + fps = video_reader.get_avg_fps() + duration = vlen / float(fps) + if clip: + start, end = clip + duration = end - start + vlen = int(duration * fps) + start_index = int(start * fps) + + # t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames) + t_num_frames = np.random.randint(min_num_frames, num_frames + 1) + + frame_indices = get_frame_indices( + t_num_frames, vlen, sample=sample, fix_start=fix_start, + input_fps=fps + ) + if clip: + frame_indices = [f + start_index for f in frame_indices] + frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8 + frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])] + return frames + + +def extract_frame_number(filename): + # Extract the numeric part from the filename using regular expressions + match = re.search(r'_(\d+).jpg$', filename) + return int(match.group(1)) if match else -1 + + +def sort_frames(frame_paths): + # Extract 
filenames from each path and sort by their numeric part + return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x))) + + +def read_frames_folder( + video_path, num_frames, sample='rand', fix_start=None, + client=None, clip=None, min_num_frames=4 +): + if 's3://' in video_path: + image_list = sort_frames(client.list(video_path)) + frames = [] + for image in image_list: + fp = os.path.join(video_path, image) + frame = Image.open(io.BytesIO(client.get(fp))) + frames.append(frame) + else: + image_list = sort_frames(list(os.listdir(video_path))) + frames = [] + for image in image_list: + fp = os.path.join(video_path, image) + frame = Image.open(fp).convert('RGB') + frames.append(frame) + vlen = len(frames) + + t_num_frames = np.random.randint(min_num_frames, num_frames + 1) + + if vlen > t_num_frames: + frame_indices = get_frame_indices( + t_num_frames, vlen, sample=sample, fix_start=fix_start + ) + frames = [frames[i] for i in frame_indices] + return frames + + +class WeightedConcatDataset(ConcatDataset): + def __init__(self, datasets, weights): + super().__init__(datasets) + self.weights = torch.DoubleTensor(weights) + self.total_size = sum(len(d) for d in datasets) + self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) + + def __iter__(self): + return iter(self.sampler) + + def __len__(self): + return self.total_size + + +def pil_loader(img_str): + buff = io.BytesIO(img_str) + img = Image.open(buff) + return img.convert('RGB') + + +class TCSLoader(object): + + def __init__(self, conf_path, sc_config_key='sensecore'): + print(f'[TCSLoader] config_path: {conf_path}') + print('--> before Client(conf_path)') + self.client = Client(conf_path) + self.sc_config_key = sc_config_key + print('--> after Client(conf_path)') + + def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None): + if image_type == 'image': + img_value_str = self.client.get(fn) + img = pil_loader(img_value_str) + return img + + elif image_type == 'video': + if fn.endswith('/'): + frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, + client=self.client, sample=sample) + elif fn.endswith('.gif'): + frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, + client=self.client, sample=sample) + else: + frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames, + client=self.client, sample=sample, clip=clip) + return frames + + +def expand2square(pil_img, background_color): + width, height = pil_img.size + if width == height: + return pil_img + elif width > height: + result = Image.new(pil_img.mode, (width, width), background_color) + result.paste(pil_img, (0, (width - height) // 2)) + return result + else: + result = Image.new(pil_img.mode, (height, height), background_color) + result.paste(pil_img, ((height - width) // 2, 0)) + return result + + +def simulate_jpeg_degradation(quality): + def jpeg_degrade(img): + with io.BytesIO() as output: + img.convert('RGB').save(output, format='JPEG', quality=quality) + output.seek(0) # Move the reading cursor to the start of the stream + img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory + return img_jpeg + return jpeg_degrade + + +# Define the JPEG compression quality range, pre-create all JPEG compression functions +qualities = list(range(75, 101)) +jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality 
in qualities} + + +def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'): + if normalize_type == 'imagenet': + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + elif normalize_type == 'clip': + MEAN, STD = CLIP_MEAN, CLIP_STD + elif normalize_type == 'siglip': + MEAN, STD = SIGLIP_MEAN, SIGLIP_STD + else: + raise NotImplementedError + if is_train: # use data augumentation + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + else: + if pad2square is False: # now we use this transform function by default + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + else: + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + + return transform + + +def preprocess( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO + + # Mask targets. Only compute loss on the assistant outputs. + sep = conv.sep + conv.roles[1] + ': ' + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep2) + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(turns): + if turn == '': + break + turn_len = len(tokenizer(turn).input_ids) + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + # "-2" is hardcoded for the Llama tokenizer to make the offset correct. 
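To make the label-masking bookkeeping around this point concrete, here is a minimal self-contained sketch (hypothetical token counts, not tied to any real tokenizer) of how one turn's instruction portion is ignored while the assistant reply keeps its labels:

import torch

IGNORE_TOKEN_ID = -100
turn_len, instruction_len, cur_len = 16, 10, 1   # hypothetical: 10 prompt tokens + 6 answer tokens
target = torch.arange(cur_len + turn_len)        # stand-in for the real token ids
target[:cur_len] = IGNORE_TOKEN_ID                           # the leading BOS is never supervised
target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID  # mask the user/instruction part
cur_len += turn_len                              # only the 6 assistant tokens keep their labels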
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + instruction_len -= 1 + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + cur_len += turn_len + + if i != 0 and not tokenizer.legacy: + # The legacy and non-legacy modes handle special tokens differently + cur_len -= 1 + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + logger.info(tokenizer.decode(z)) + exit() + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' + ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_mpt( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # Mask targets. Only compute loss on the assistant outputs. 
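Before the masking loop below, the conversation string is split on conv.sep and regrouped so that the first chunk holds system + user + assistant and each later chunk holds one user/assistant pair. A toy walk-through (illustrative ChatML-style strings; the real template comes from get_conv_template):

sep = '<|im_end|>'
conversation = (
    '<|im_start|>system\nYou are a helpful assistant.' + sep +
    '<|im_start|>user\nHi' + sep + '<|im_start|>assistant\nHello!' + sep +
    '<|im_start|>user\nBye' + sep + '<|im_start|>assistant\nGoodbye!' + sep
)
turns = conversation.split(sep)
re_turns = [sep.join(turns[:3])]                           # system + user + gpt
for conv_idx in range(3, len(turns), 2):
    re_turns.append(sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
# re_turns -> [system+user+gpt, user+gpt, ''];
# the trailing empty chunk is skipped by the `if turn == '': break` check,
# and inside each chunk everything up to the assistant role marker is masked.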
+ sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) + + turns = conversation.split(conv.sep) + re_turns = [conv.sep.join(turns[:3])] # system + user + gpt + for conv_idx in range(3, len(turns), 2): + re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt + cur_len = 0 + target[:cur_len] = IGNORE_TOKEN_ID + for i, turn in enumerate(re_turns): + if turn == '': + break + turn_len = len(tokenizer(turn).input_ids) + 1 + + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + instruction_len = len(tokenizer(parts[0]).input_ids) + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0])) + # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0])) + # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len]) + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' + ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_phi3( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + tokenizer.padding_side = 'right' + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + # Mask targets. Only compute loss on the assistant outputs. 
+ sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|> + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(int(tokenizer.pad_token_id)).sum()) + + turns = conversation.split(conv.sep) + re_turns = [conv.sep.join(turns[:3])] # system + user + gpt + for conv_idx in range(3, len(turns), 2): + re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID + endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>') + target[target == endoftext_id] = IGNORE_TOKEN_ID + + for i, turn in enumerate(re_turns): + if turn == '': + break + if i == 0: + turn_len = len(tokenizer(turn).input_ids) + else: + turn_len = len(tokenizer(turn).input_ids) - 1 + parts = turn.split(sep) + if len(parts) != 2: + break + parts[0] += sep + + if i == 0: + instruction_len = len(tokenizer(parts[0]).input_ids) - 1 + else: + instruction_len = len(tokenizer(parts[0]).input_ids) - 2 + + # Ignore the user instructions + target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID + # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0])) + # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0])) + # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len]) + cur_len += turn_len + + target[cur_len:] = IGNORE_TOKEN_ID + + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + print(repr(tokenizer.decode(z))) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print( + f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' + f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
+ ) + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_internlm( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + conv = get_conv_template(template_name) + roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} + + # Apply prompt templates + conversations = [] + for i, source in enumerate(sources): + if roles[source[0]['from']] != conv.roles[0]: + # Skip the first one if it is not from human + source = source[1:] + + conv.messages = [] + for j, sentence in enumerate(source): + role = roles[sentence['from']] + assert role == conv.roles[j % 2], f'{i}' + sentence['value'] = sentence['value'].strip() + conv.append_message(role, sentence['value']) + conversations.append(conv.get_prompt()) + + if not text_only: + new_conversations = [] + for conversation in conversations: + for i in range(num_image): + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' + conversation = conversation.replace('', image_tokens, 1) + new_conversations.append(conversation) + conversations = new_conversations + + # Tokenize conversations + input_ids = tokenizer( + conversations, + return_tensors='pt', + padding=False if group_by_length or use_packed_ds else 'max_length', + max_length=tokenizer.model_max_length, + truncation=True, + ).input_ids + targets = input_ids.clone() + + for conversation, target in zip(conversations, targets): + total_len = int(target.ne(tokenizer.pad_token_id).sum()) # 浦语里面 pad_token_id = eos_token_id + cur_len = 1 + target[:cur_len] = IGNORE_TOKEN_ID # + parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n + info = parts[0] + conv.roles[1] + temp_len = len(tokenizer(info).input_ids) - 1 # 去除tokenizer的 + target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID + cur_len = cur_len + temp_len + + for index in range(1, len(parts) - 1): + info = parts[index] + part1, part2 = info.split(conv.roles[0]) + temp_len = len(tokenizer(part1).input_ids) - 1 + cur_len = cur_len + temp_len + part = conv.roles[0] + part2 + conv.roles[1] + temp_len = len(tokenizer(part).input_ids) - 1 + target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID + cur_len = cur_len + temp_len + last_info = parts[-1] + temp_len = len(tokenizer(last_info).input_ids) - 1 + cur_len = cur_len + temp_len + + target[cur_len:] = IGNORE_TOKEN_ID + if False: # Inspect and check the correctness of masking + z = target.clone() + z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) + print(repr(tokenizer.decode(z))) + + if cur_len < tokenizer.model_max_length: + if cur_len != total_len: + target[:] = IGNORE_TOKEN_ID + print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. 
This dataset is {ds_name}.') + sys.stdout.flush() + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def preprocess_internvl2_5( + template_name, + sources, + tokenizer: transformers.PreTrainedTokenizer, + num_image_token_list: list, + text_only: bool = False, + group_by_length: bool = False, + use_packed_ds: bool = False, + ds_name: str = None, + num_image: int = 1 +) -> Dict: + assert len(sources) == 1, 'process only the first conversations' + conversations = sources[0] + + if conversations[0]['from'] == 'system': + system_prompt = conversations[0]['value'] + conversations = conversations[1:] # remove system prompt + else: + conv = get_conv_template(template_name) + system_prompt = conv.system_message + # system_prompt = None + + if not text_only: + new_conversations = [] + current_image_idx = 0 + for conversation in conversations: + if conversation['from'] == 'human': + image_cnt = conversation['value'].count('') + for i in range(image_cnt): + if current_image_idx == num_image: + break + image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}' + conversation['value'] = conversation['value'].replace('', image_tokens, 1) + current_image_idx += 1 + new_conversations.append(conversation) + conversations = new_conversations + assert current_image_idx == num_image, f'{current_image_idx} != {num_image}' + + batches, roles = [], [] + if system_prompt is not None: + batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n') + roles.append('system') + for conversation in conversations: + if conversation['from'] == 'human': + batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n') + roles.append('human') + elif conversation['from'] == 'gpt': + batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n') + roles.append('gpt') + else: + raise NotImplementedError + + add_bos_token = getattr(tokenizer, 'add_bos_token', False) + if add_bos_token: # for InternLM series + batches[0] = tokenizer.bos_token + batches[0] + + # Tokenize conversations + input_ids = tokenizer( + batches, + return_tensors='np', + padding=False, + max_length=tokenizer.model_max_length, + truncation=False, + ).input_ids + + if add_bos_token: # for InternLM series + input_ids = [item[1:] for item in input_ids] + + final_input_ids, final_targets = [], [] + ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0] + ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0] + for role, input_id in zip(roles, input_ids): + final_input_ids.append(input_id) + if role == 'system' or role == 'human': + final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID)) # ignore + elif role == 'gpt': + target = input_id.copy() + target[:ignore_len] = IGNORE_TOKEN_ID # ignore loss for `<|im_start|>assistant\n` + target[-1:] = IGNORE_TOKEN_ID # ignore loss for `\n` + final_targets.append(target) + else: + raise NotImplementedError + input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length] + targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length] + + padding = False if group_by_length or use_packed_ds else True + if padding: + current_length = input_ids.size(0) + padding_length = tokenizer.model_max_length - current_length + input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id) + targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID) + + 
input_ids = input_ids.unsqueeze(0) + targets = targets.unsqueeze(0) + + return dict( + input_ids=input_ids, + labels=targets, + attention_mask=input_ids.ne(tokenizer.pad_token_id), + ) + + +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}') + return best_ratio + + +def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + if return_ratio: + return processed_images, target_aspect_ratio + return processed_images + + +def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): + # import pdb + length, orig_height, orig_width = image.shape + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + # print(target_aspect_ratio) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # resize the image + + tensor_images = image.unsqueeze(1) # 添加一个维度作为单通道 + # pdb.set_trace() + resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False) #(1792,1344) + resized_images = resized_images > 0 + # 
print(resized_images.shape) + # 然后像 PIL 那样裁剪图像块 + processed_images = [] + for i in range(blocks): + top = (i // (target_width // image_size)) * image_size + left = (i % (target_width // image_size)) * image_size + bottom = top + image_size + right = left + image_size + # 使用 tensor 切片进行裁剪 + split_img = resized_images[..., top:bottom, left:right] # 这里使用...来保持通道这一维度 + processed_images.append(split_img) + # plt.imshow(split_img.sum(0).squeeze()) + # plt.savefig(f'/workdir/guantongkun/12490719/eef5a3b245897c9f4335463fb12fed35/work_dirs/{i}_mask.jpg', dpi=600) + # pdb.set_trace() + # 最后,如果您需要,可以对处理过的图像list进行任何后续操作 + # 例如,convert回通道为最后维度的形式,如果是单通道的话 + processed_images = [img.squeeze(1) for img in processed_images] + + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False).squeeze(1) + thumbnail_img = thumbnail_img > 0 + # Image.fromarray(thumbnail_img.cpu().numpy().astype(np.uint8)) + processed_images.append(thumbnail_img) + return processed_images diff --git a/internvl/train/dataset_packed.py b/internvl/train/dataset_packed.py new file mode 100644 index 0000000000000000000000000000000000000000..3b54ed47524b054ab54d737457626da785723fe7 --- /dev/null +++ b/internvl/train/dataset_packed.py @@ -0,0 +1,634 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import bisect +import copy +import logging +from collections import defaultdict +from typing import List, Union + +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import IterableDataset, get_worker_info +from transformers.trainer_pt_utils import LabelSmoother + +from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +class PackedDataset(IterableDataset): + def __init__( + self, + tokenizer, + data_rank, + data_world_size, + datasets: List, + dataset_weight: List[int] = None, + num_images_expected: int = 6, + max_packed_tokens: int = 32768, + max_buffer_size: int = 100, + log_freq: int = 1000000, + strict_mode: bool = False, + debug_mode: bool = False, + replacement: bool = True, + allow_overflow: bool = True, + allow_empty_data: bool = False, + allow_deduplicated_ds_name: bool = False, + ): + super().__init__() + self.tokenizer = tokenizer + self.data_rank = data_rank + self.data_world_size = data_world_size + self.datasets = datasets + self.num_images_expected = num_images_expected + self.max_buffer_size = max_buffer_size + self.log_freq = log_freq + self.strict_mode = strict_mode + self.debug_mode = debug_mode + self.replacement = replacement + self.allow_overflow = allow_overflow + self.allow_empty_data = allow_empty_data + + self.max_packed_tokens = max_packed_tokens + + self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN) + self.img_token_id = 
self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN) + + assert self.img_start_token_id != self.tokenizer.unk_token_id + assert self.img_token_id != self.tokenizer.unk_token_id + assert self.img_end_token_id != self.tokenizer.unk_token_id + + if dataset_weight is None: + dataset_weight = [1] * len(datasets) + self.dataset_type = [d.dataset_type for d in self.datasets] + + self.datasets_orig = datasets + self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight] + + self.datasets = [ds for ds in self.datasets_orig] + self.dataset_weight = [w for w in self.dataset_weight_orig] + + # lazy init + self.worker_id = None + self.worker_state_key = None + self.dataset_iter_list = None + self._state_dict = { + 'sample_info': {d.ds_name:0 for d in self.datasets}, + } + + self.worker_custom_infos = None + + ds_name_list = [d.ds_name for d in self.datasets] + if not allow_deduplicated_ds_name: + assert len(ds_name_list) == len(set(ds_name_list)), f'deduplicated ds_name: {ds_name_list}' + + for ds in self.datasets: + if ds.max_num_images > self.num_images_expected: + logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}') + ds.max_num_images = num_images_expected + + if ds.max_tokens > self.max_packed_tokens: + logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}') + ds.max_tokens = self.max_packed_tokens + + self._state_dict[ds.ds_name] = {} + + if get_rank() == 0: + logger.info( + f'Loaded dataset to pack: {ds_name_list}, ' + f'{self.num_images_expected=}, {self.max_packed_tokens=}, ' + f'{self.replacement=}, {self.allow_overflow=}', + ) + + temp = [] + for ds, ds_w in zip(self.datasets, self.dataset_weight): + temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%') + temp = '\n'.join(temp) + logger.info( + f'Sampling prob for each dataset:\n{temp}' + ) + + if self.allow_empty_data: + logger.warning('allow_empty_data is enabled, note that empty data may be generated!') + + def load_state_dict(self, state_dict, custom_infos=None): + + self.worker_custom_infos = custom_infos + + self._state_dict.update(state_dict) + for ds in self.datasets: + if ds.ds_name in self._state_dict: + ds.load_state_dict(self._state_dict[ds.ds_name]) + logger.info(f'{ds.ds_name=} is resumed.') + else: + logger.warning(f'{ds.ds_name=} is not resumed.') + + def _should_log(self): + worker_id = 0 if get_worker_info() is None else get_worker_info().id + num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers + + worker_id = num_workers * get_rank() + worker_id + num_workers = num_workers * get_world_size() + + return worker_id == 0 + + def next_data(self, current_dataset_idx): + while True: + try: + current_sample = next(self.dataset_iter_list[current_dataset_idx]) + break # Exit loop if successful + except StopIteration: + if self.replacement: + # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.') + try: + self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx]) + current_sample = next(self.dataset_iter_list[current_dataset_idx]) + break + except: + # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! 
length={len(self.datasets)}') + self.datasets.pop(current_dataset_idx) + self.dataset_iter_list.pop(current_dataset_idx) + self.dataset_weight.pop(current_dataset_idx) + + if len(self.datasets) == 0: + raise StopIteration + current_dataset_idx = np.random.choice(len(self.datasets)) + else: + # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}') + self.datasets.pop(current_dataset_idx) + self.dataset_iter_list.pop(current_dataset_idx) + self.dataset_weight.pop(current_dataset_idx) + + if len(self.datasets) == 0: + raise StopIteration + current_dataset_idx = np.random.choice(len(self.datasets)) + except: + logger.error('Unexpected error!') + if len(self.datasets) == 0: + raise StopIteration + current_dataset_idx = np.random.choice(len(self.datasets)) + + current_ds_name = self.datasets[current_dataset_idx].ds_name + current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx + + if self.worker_state_key not in self._state_dict[current_ds_name]: + self._state_dict[current_ds_name][self.worker_state_key] = {} + + meta_info = current_sample.pop('meta_info', {}) + self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info) + self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1 + return current_sample + + def find_buffer(self, buffer_list, new_sample): + # NOTE: use `bisect` to search might be faster + + find = False + find_idx = -1 + num_images_current = new_sample['pixel_values'].size(0) + for buffer_idx, buffer in enumerate(buffer_list): + num_images_buffer = buffer['pixel_values'].size(0) + if num_images_buffer + num_images_current <= self.num_images_expected: + num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0) + + if num_merged_tokens <= self.max_packed_tokens: + find = True + find_idx = buffer_idx + break + + if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2: + find = True + find_idx = buffer_idx + + if find: + return buffer_list.pop(find_idx) + return None + + def update_buffer(self, buffer, new_sample): + if buffer is None: + new_sample['data_index'] = torch.zeros_like(new_sample['input_ids']) + return new_sample + + new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item() + + assert buffer.keys() == new_sample.keys() + for k in buffer: + buffer[k] = torch.cat([buffer[k], new_sample[k]]) + return buffer + + @staticmethod + def check_valid(sample_to_check, min_active_tokens_ratio=1/256): + num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum() + num_tokens = sample_to_check['labels'].numel() + return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio + + @staticmethod + def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id): + if buffer['input_ids'].size(0) <= max_tokens: + return [buffer] + + def _image_is_splitted(input_ids, cut_idx): + is_image_start = input_ids[cut_idx].item() == img_start_token_id + is_image_token = input_ids[cut_idx].item() == img_token_id + is_image_end = input_ids[cut_idx].item() == img_end_token_id + return is_image_start or is_image_token or is_image_end + + def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx): + assert (right_idx is None) == (right_img_idx is None) + + left_sample = {} + right_sample = {} if right_idx is not None else None + for k in sample_to_split: + if k in ['input_ids', 'labels', 
'attention_mask', 'position_ids', 'data_index', 'type_ids']: + left_sample[k] = sample_to_split[k][:left_idx] + if right_sample is not None: + right_sample[k] = sample_to_split[k][right_idx:] + elif k in ['pixel_values', 'image_flags']: + left_sample[k] = sample_to_split[k][:left_img_idx] + if right_sample is not None: + right_sample[k] = sample_to_split[k][right_img_idx:] + else: + raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}') + return left_sample, right_sample + + splitted_buffer = [] + while buffer['input_ids'].size(0) > max_tokens: + img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist() + img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist() + assert len(img_start_idx_list) == len(img_end_idx_list) + + if _image_is_splitted(buffer['input_ids'], max_tokens): + cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens) + if buffer['input_ids'][max_tokens] == img_start_token_id: + assert max_tokens == img_start_idx_list[cut_idx] + cut_left_idx = img_start_idx_list[cut_idx] + cut_left_img_idx = cut_idx + else: + cut_left_idx = img_start_idx_list[cut_idx - 1] + cut_left_img_idx = cut_idx - 1 + cut_right_idx = cut_left_idx + cut_right_img_idx = cut_left_img_idx + else: + cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens) + if cut_img_idx < len(img_start_idx_list): + cut_right_idx = img_start_idx_list[cut_img_idx] + cut_right_img_idx = cut_img_idx + else: + cut_right_idx = None + cut_right_img_idx = None + + cut_left_idx = max_tokens + cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0) + + left, right = _split( + sample_to_split=buffer, + left_idx=cut_left_idx, + left_img_idx=cut_left_img_idx, + right_idx=cut_right_idx, + right_img_idx=cut_right_img_idx, + ) + + assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0) + if right is not None: + assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0) + + if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left): + splitted_buffer.append(left) + + if right is None or right['pixel_values'].size(0) == 0: + break + + buffer = right + if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer): + splitted_buffer.append(buffer) + break + + logger.debug( + f'split a sample into {len(splitted_buffer)} samples, ' + f'current max_tokens={max_tokens}' + ) + return splitted_buffer + + def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer): + # NOTE: in-place operation + + splitted_buffer = PackedDataset.split_buffer( + buffer=buffer, + max_tokens=self.max_packed_tokens, + img_start_token_id=self.img_start_token_id, + img_token_id=self.img_token_id, + img_end_token_id=self.img_end_token_id, + ) + + for each_buffer in splitted_buffer: + if each_buffer['pixel_values'].size(0) > self.num_images_expected: + logger.error( + f"Find a sample with {each_buffer['pixel_values'].size(0)} images, " + f'which exceeds {self.num_images_expected}' + ) + continue + + if each_buffer['input_ids'].size(0) >= self.max_packed_tokens: + assert each_buffer['input_ids'].size(0) == self.max_packed_tokens + buffer_max_len_list.append(each_buffer) + continue + + find_idx = len(buffer_list) + num_images_new_sample = each_buffer['pixel_values'].size(0) + for buffer_idx in 
range(len(buffer_list)): + if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample: + find_idx = buffer_idx + break + buffer_list.insert(find_idx, each_buffer) + + for i in range(1, len(buffer_list)): + assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0) + + return buffer_list, buffer_max_len_list + + def pad_buffer(self, buffer): + if buffer['pixel_values'].size(0) == self.num_images_expected: + return buffer + + num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0) + pad_images = torch.stack([ + torch.zeros_like(buffer['pixel_values'][0]) + for _ in range(num_pad_images) + ]) + pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long) + + buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images]) + buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags]) + + return buffer + + def postprocess_buffer(self, buffer, custom_infos=None): + buffer['worker_state_key'] = self.worker_state_key + buffer['worker_state_dict'] = self._state_dict + if custom_infos is not None: + buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)} + return buffer + + def print_log(self, iter_idx, buffer_list): + if iter_idx % self.log_freq != 0: + return + + if self._should_log(): + logger.info( + f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}" + ) + + def __iter__(self): + iter_idx = 0 + buffer_list = [] + buffer_max_len_list = [] + + if self._should_log(): + logger.info(f'Begin to iter, {len(buffer_list)=}') + + worker_id = 0 if get_worker_info() is None else get_worker_info().id + num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers + + worker_id = num_workers * self.data_rank + worker_id + num_workers = num_workers * self.data_world_size + + rng = np.random.default_rng(seed=worker_id) + + # reset states of each dataset + self.worker_id = worker_id + self.worker_state_key = f'work_state_{self.worker_id}' + self.datasets = [d for d in self.datasets_orig] + self.dataset_weight = [w for w in self.dataset_weight_orig] + self.dataset_iter_list = [iter(d) for d in self.datasets] + + for ds in self.datasets: + # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)): + ds.worker_id = worker_id + ds.worker_state_key = f'work_state_{self.worker_id}' + ds.num_workers = num_workers + if self._should_log() and worker_id == 0: + logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}') + + if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos: + custom_infos = self.worker_custom_infos[self.worker_state_key] + # buffer list + if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list): + buffer_list = custom_infos['buffer_list'] + if self._should_log() and worker_id == 0: + logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}') + # other infos + + # reset + self.worker_custom_infos = None + + logger.debug( + f'{self.__class__.__name__} Rank {self.data_rank} ' + f'Worker {worker_id} begin to load data' + ) + + while True: + self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight] + current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight) + + try: + current_sample = self.next_data(current_dataset_idx) + except: + logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})') + while len(buffer_list) > 0: + if 
self.strict_mode: + yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0))) + else: + yield self.postprocess_buffer(buffer_list.pop(0)) + logger.info(f'buffer_list is empty! ({len(buffer_list)=})') + return + + buffer = self.find_buffer(buffer_list, current_sample) + buffer = self.update_buffer(buffer, current_sample) + buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer) + + while len(buffer_max_len_list) > 0: + if buffer_max_len_list[0]['pixel_values'].size(0) != self.max_packed_tokens: + logger.debug( + f'num tokens of a buffer exceed {self.max_packed_tokens=}, ' + f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images" + ) + if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected: + # buffer_max_len_list.pop(0) + yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list}) + else: + yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list}) + + while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected: + logger.error( + f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) " + f'is larger than num_images_expected({self.num_images_expected})' + ) + buffer_list.pop(0) + + while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected: + if self.debug_mode: + debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) + while True: + yield debug_data.copy() + + yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) + + while len(buffer_list) > self.max_buffer_size: + logger.debug( + f'Failed to pack data to exactly {self.num_images_expected} images, ' + f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images." 
+ ) + if self.strict_mode: + yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list}) + else: + yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list}) + + self.print_log(iter_idx=iter_idx, buffer_list=buffer_list) + iter_idx += 1 + + @staticmethod + def get_cu_seqlens_and_indexes( + data_index: torch.LongTensor, # (seq_len,) + input_ids: torch.LongTensor, # (seq_len,) + labels: torch.LongTensor, # (seq_len,) + len2weight: callable, + ): + indexes = [] + cu_seqlens = [0] + loss_weight = [] + + start = data_index.min() + end = data_index.max() + 1 + for i in range(start, end): + num_tokens = (data_index == i).sum().item() + indexes.extend(list(range(num_tokens))) + cu_seqlens.append(cu_seqlens[-1] + num_tokens) + assert num_tokens > 0 + + curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens] + assert (curr_data_index == i).all(), data_index + + curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens] + num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item() + loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens) + + assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}' + + loss_weight = torch.tensor(loss_weight, dtype=torch.float32) + return cu_seqlens, indexes, loss_weight + + +WARNING_CNT = defaultdict(int) + + +def packed_collate_fn( + features, + data_collator, + len2weight: callable, + max_item_length: int, + micro_num: int = 1, + loss_reduction_all_gather: bool = False, + pad_id: int = 0, +): + if not isinstance(features, list): + features = [features] + + if len(features) > micro_num: + raise NotImplementedError(f'{len(features)=} > {micro_num=}') + + if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5: + logger.warning( + f'{len(features)=} > {micro_num=}, ' + f'the features will be padded to satisfy micro_num requirement' + ) + WARNING_CNT['micro_num_warning'] += 1 + + # ensure that the len(features) is equal to the required micro_num + num_features = len(features) + while len(features) < micro_num: + features.append(copy.deepcopy(features[0])) + features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID) + + indexes = [] + cu_seqlens = [] + cu_num_images_list = [0] + + worker_state_key_list = [] + worker_state_dict_list = [] + worker_state_custom_infos_list = [] + + batch_lens = [feat['input_ids'].shape for feat in features] + max_item_length = max_item_length or max(batch_lens)[0] + + num_samples = 0 + num_padding_tokens = 0 + for feat_idx, feat in enumerate(features): + data_index = feat.pop('data_index') + curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes( + data_index=data_index, + input_ids=feat['input_ids'], + labels=feat['labels'], + len2weight=len2weight, + ) + + feat['loss_weight'] = curr_loss_weight + + if feat_idx < num_features: + num_samples += len(curr_cu_seqlens) - 1 + + if curr_cu_seqlens[-1] < max_item_length: + curr_cu_seqlens.append(max_item_length) + curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2]))) + + indexes.append(torch.tensor(curr_indexes, dtype=torch.long)) + cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32)) + + worker_state_key_list.append(feat.pop('worker_state_key')) + worker_state_dict_list.append(feat.pop('worker_state_dict')) + worker_state_custom_infos_list.append(feat.pop('custom_infos', None)) + + num_padding_tokens += (max_item_length - feat['input_ids'].size(0)) + 
cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0)) + + batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id) + # convert it to list in case it is converted into bf16 + batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist() + batch['attention_mask'] = torch.stack(cu_seqlens) + batch['loss_reduction_all_gather'] = loss_reduction_all_gather + batch['statistics'] = torch.tensor( + [ + num_samples, + num_padding_tokens, + batch['image_flags'].numel() - batch['image_flags'].sum().item(), + ], + dtype=torch.long, + ) + batch.pop('type_ids') + return batch diff --git a/internvl/train/trainer_dpo.py b/internvl/train/trainer_dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a3ed624a43900f902a9dd48b7182b8a21bf6d2 --- /dev/null +++ b/internvl/train/trainer_dpo.py @@ -0,0 +1,287 @@ +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +from typing import Dict, List, Literal, Optional, Tuple, Union + +import torch +from torch import nn +from torch.utils.data import ConcatDataset +from trl import DPOTrainer +from trl.trainer.utils import RunningMoments, pad_to_length + + +def _map(self, *args, **kwargs): + return self + + +ConcatDataset.map = _map + + +class MultimodalDPOTrainer(DPOTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.loss_type != 'bco_pair' and 'bco_pair' in self.loss_type: + self.running = RunningMoments(self.accelerator) + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + is_vision_model: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. 
+ """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch['chosen_labels'].shape[1], batch['rejected_labels'].shape[1]) + else: + max_length = max(batch['chosen_input_ids'].shape[1], batch['rejected_input_ids'].shape[1]) + + for k in batch: + if k.startswith('chosen') and isinstance(batch[k], torch.Tensor): + if 'labels' in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith('_input_ids'): + pad_value = padding_value + elif k.endswith('_attention_mask'): + pad_value = 0 + concatenated_key = k.replace('chosen', 'concatenated') + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith('rejected') and isinstance(batch[k], torch.Tensor): + if 'labels' in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith('_input_ids'): + pad_value = padding_value + elif k.endswith('_attention_mask'): + pad_value = 0 + concatenated_key = k.replace('rejected', 'concatenated') + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch['concatenated_input_ids'] = batch['prompt_input_ids'].repeat(2, 1).to(device=device) + concatenated_batch['concatenated_attention_mask'] = ( + batch['prompt_attention_mask'].repeat(2, 1).to(device=device) + ) + + if 'pixel_values' in batch: + concatenated_batch['pixel_values'] = batch['pixel_values'].repeat(2, 1, 1, 1) + concatenated_batch['image_flags'] = batch['image_flags'].repeat(2) + + return concatenated_batch + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. 
+ """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + is_vision_model=self.is_vision_model, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch['chosen_labels'].shape[0] + + model_kwargs = {} + + if self.is_encoder_decoder: + model_kwargs['labels'] = concatenated_batch['concatenated_labels'] + model_kwargs['decoder_input_ids'] = concatenated_batch.pop('concatenated_decoder_input_ids', None) + + if self.is_vision_model: + model_kwargs['pixel_values'] = concatenated_batch['pixel_values'] + model_kwargs['pixel_attention_mask'] = concatenated_batch['pixel_attention_mask'] + + if self.aux_loss_enabled: + model_kwargs['output_router_logits'] = True + + outputs = model( + input_ids=concatenated_batch['concatenated_input_ids'], + attention_mask=concatenated_batch['concatenated_attention_mask'], + pixel_values=concatenated_batch['pixel_values'], + image_flags=concatenated_batch['image_flags'], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + all_logps, size_completion = self.get_batch_logps( + all_logits, + concatenated_batch['concatenated_labels'], + # average_log_prob=self.loss_type == "ipo", + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = concatenated_batch['concatenated_labels'].clone() + nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + if self.loss_type == 'ipo': + all_logps = all_logps / size_completion + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss) + + def _prepare_deepspeed(self, model): + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if config_kwargs['zero_optimization']['stage'] == 3: + print('Enable DPOTrainer._prepare_deepspeed') + return super()._prepare_deepspeed(model) + + print('Disable DPOTrainer._prepare_deepspeed') + for param in model.parameters(): + param.requires_grad = False + + model.eval() + model = model.to(self.accelerator.device) + return model + + def get_batch_loss_metrics( + self, + model, + batch: Dict[str, Union[List, torch.LongTensor]], + train_eval: Literal['train', 'eval'] = 'train', + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + # if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model 
+ if ( + 'reference_chosen_logps' in batch + and 'reference_rejected_logps' in batch + and self.args.rpo_alpha is not None + ): + reference_chosen_logps = batch['reference_chosen_logps'] + reference_rejected_logps = batch['reference_rejected_logps'] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + _, + ) = self.concatenated_forward(self.model, batch) + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + if ',' in self.loss_type: + loss_type = self.loss_type + loss_type_list = loss_type.split(',') + + losses, chosen_rewards, rejected_rewards = 0, 0, 0 + for curr_type in loss_type_list: + self.loss_type = curr_type + curr_losses, curr_chosen_rewards, curr_rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + curr_weight = getattr(self.args, f'{curr_type}_loss_weight') + losses = losses + curr_losses * curr_weight + chosen_rewards = chosen_rewards + curr_chosen_rewards * curr_weight + rejected_rewards = rejected_rewards + curr_rejected_rewards * curr_weight + + self.loss_type = loss_type + else: + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + if self.args.rpo_alpha is not None: + # losses = losses * self.args.rpo_alpha + policy_nll_loss + losses = losses + policy_nll_loss * self.args.rpo_alpha + + prefix = 'eval_' if train_eval == 'eval' else '' + metrics[f'{prefix}rewards/chosen'] = chosen_rewards.mean().cpu() + metrics[f'{prefix}rewards/rejected'] = rejected_rewards.mean().cpu() + metrics[f'{prefix}rewards/accuracies'] = reward_accuracies.mean().cpu() + metrics[f'{prefix}rewards/margins'] = (chosen_rewards - rejected_rewards).mean().cpu() + metrics[f'{prefix}logps/rejected'] = policy_rejected_logps.detach().mean().cpu() + metrics[f'{prefix}logps/chosen'] = policy_chosen_logps.detach().mean().cpu() + metrics[f'{prefix}logits/rejected'] = policy_rejected_logits.detach().mean().cpu() + metrics[f'{prefix}logits/chosen'] = policy_chosen_logits.detach().mean().cpu() + if self.args.rpo_alpha is not None: + metrics[f'{prefix}nll_loss'] = policy_nll_loss.detach().mean().cpu() + + if self.aux_loss_enabled: + return losses.mean() + getattr(model.config, 'router_aux_loss_coef', 0.0) * aux_loss, metrics + + return losses.mean(), metrics diff --git a/internvl/train/trainer_monkey_patch.py b/internvl/train/trainer_monkey_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..e5c501c27fc17404b9ba992cf578f7a475e3abe3 --- /dev/null +++ b/internvl/train/trainer_monkey_patch.py @@ -0,0 +1,246 @@ +import json +import os + +import torch +import torch.nn as nn +import transformers +from transformers import Trainer, logging +from transformers.trainer import is_sagemaker_mp_enabled + +logger = logging.get_logger(__name__) + + +def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer): + if var_name.startswith('internvl.'): + var_name = var_name[len('internvl.'):] + if var_name in ('query_tokens', 'logit_scale',): + return 0 + if var_name.startswith('clip_projector.'): + return vit_num_max_layer + if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \ + var_name == 
'text_projection': + return llama_num_max_layer + if var_name.startswith('vision_model.'): + if 'embeddings.' in var_name: + return 0 + if 'layers.' in var_name: + var_name = var_name.split('layers.')[-1] + layer_id = int(var_name.split('.')[0]) + return layer_id + 1 + if var_name.startswith('qllama.'): + if 'embed_tokens' in var_name: + return 0 + if 'layers.' in var_name: + var_name = var_name.split('layers.')[-1] + layer_id = int(var_name.split('.')[0]) + return layer_id + 1 + else: + return llama_num_max_layer + return 0 + + +def param_classification(name): + if name.startswith('internvl.'): + name = name[len('internvl.'):] + if name in ['query_tokens', 'text_projection', 'logit_scale']: + return 'qllama' + elif name.startswith('vision_model.'): + return 'vit' + elif name.startswith('qllama.'): + return 'qllama' + elif name.startswith('clip_projector.'): + return 'vit' + elif name.startswith('clip_projector2.'): + return 'qllama' + elif name.startswith('itm_head.'): + return 'qllama' + else: + return 'other' + + +def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + # import pdb; pdb.set_trace() + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + parameter_groups = {} + try: # for stage2 model + vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2 + qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2 + except: # for stage3 model + vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2 + qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2 + print('vit_num_layers:', vit_num_layers) + print('qllama_num_layers:', qllama_num_layers) + + vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0)) + qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0)) + qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0)) + print('vit_layer_decay_rate:', vit_layer_decay_rate) + print('qllama_layer_decay_rate:', qllama_layer_decay_rate) + print('qllama_lr_scale:', qllama_lr_scale) + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias'): + group_name = 'no_decay' + this_weight_decay = 0. 
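+            # 1-D tensors (norm scales) and biases are exempt from weight decay;
+            # everything else uses the trainer's configured weight_decay.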
+ else: + group_name = 'decay' + this_weight_decay = self.args.weight_decay + + cls = param_classification(name) + layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers) + group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name) + if group_name not in parameter_groups: + if cls == 'vit': + scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1) + elif cls == 'qllama': + scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1) + scale = scale * qllama_lr_scale + else: + scale = 1.0 + scale = min(1.0, scale) + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.args.learning_rate, + } + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + + rank = torch.distributed.get_rank() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print('Param groups = %s' % json.dumps(to_display, indent=2)) + + optimizer_grouped_parameters = list(parameter_groups.values()) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == 'Adam8bit': + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f'skipped {module}: {skipped / 2 ** 20}M params') + manager.register_module_override(module, 'weight', {'optim_bits': 32}) + logger.debug(f'bitsandbytes: will optimize {module} in fp32') + logger.info(f'skipped: {skipped / 2 ** 20}M params') + + if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + +def create_optimizer_custom(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + # import pdb; pdb.set_trace() + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + parameter_groups = {} + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith('.bias'): + group_name = 'no_decay' + this_weight_decay = 0. 
+ else: + group_name = 'decay' + this_weight_decay = self.args.weight_decay + + if 'ocr_mlp' in name or 'upsample' in name: + group_name = '%s_%s' % ('modify', group_name) + elif 'vision_model' in name: + group_name = '%s_%s' % ('vit', group_name) + else: + group_name = '%s_%s' % ('base', group_name) + + if group_name not in parameter_groups: + if 'ocr_mlp' in name or 'upsample' in name: + scale = 1.0 + elif 'vision_model' in name: + scale = 0.05 + else: + scale = 1.0 + + parameter_groups[group_name] = { + 'weight_decay': this_weight_decay, + 'params': [], + 'param_names': [], + 'lr_scale': scale, + 'group_name': group_name, + 'lr': scale * self.args.learning_rate, + } + parameter_groups[group_name]['params'].append(param) + parameter_groups[group_name]['param_names'].append(name) + + rank = torch.distributed.get_rank() + if rank == 0: + to_display = {} + for key in parameter_groups: + to_display[key] = { + 'param_names': parameter_groups[key]['param_names'], + 'lr_scale': parameter_groups[key]['lr_scale'], + 'lr': parameter_groups[key]['lr'], + 'weight_decay': parameter_groups[key]['weight_decay'], + } + print('Param groups = %s' % json.dumps(to_display, indent=2)) + + + optimizer_grouped_parameters = list(parameter_groups.values()) + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == 'Adam8bit': + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f'skipped {module}: {skipped / 2 ** 20}M params') + manager.register_module_override(module, 'weight', {'optim_bits': 32}) + logger.debug(f'bitsandbytes: will optimize {module} in fp32') + logger.info(f'skipped: {skipped / 2 ** 20}M params') + + if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + +def replace_create_optimizer(): + print('Replace original create_optimizer with custom create_optimizer') + # transformers.Trainer.create_optimizer = create_optimizer + transformers.Trainer.create_optimizer = create_optimizer_custom diff --git a/resnet50/__init__.py b/resnet50/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5455d84f91037c842c9184e9ca64720aadf6dcd2 --- /dev/null +++ b/resnet50/__init__.py @@ -0,0 +1,5 @@ +from .model import build + + +def build_model(args): + return build(args) \ No newline at end of file diff --git a/resnet50/__pycache__/__init__.cpython-310.pyc b/resnet50/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..132654c9f4bcd31af1a0213144d4c708e241f413 Binary files /dev/null and b/resnet50/__pycache__/__init__.cpython-310.pyc differ diff --git a/resnet50/__pycache__/backbone.cpython-310.pyc b/resnet50/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6039b96ba94844ca55c948363a0fae273c108aec Binary files /dev/null and b/resnet50/__pycache__/backbone.cpython-310.pyc differ diff --git a/resnet50/__pycache__/model.cpython-310.pyc b/resnet50/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..186fec8bfdb6980609fe1a8bea27c0d883834954 Binary files /dev/null 
and b/resnet50/__pycache__/model.cpython-310.pyc differ diff --git a/resnet50/backbone.py b/resnet50/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..63363e6f1c3db58644ae8120d62b5a9a5585ae27 --- /dev/null +++ b/resnet50/backbone.py @@ -0,0 +1,85 @@ +from collections import OrderedDict + +import os +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list): + xs = self.body(tensor_list) + return xs + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers=False, + dilation=False): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=False, norm_layer=nn.BatchNorm2d) + + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +def build_backbone(args): + backbone = Backbone('resnet50', train_backbone=True) + return backbone diff --git a/resnet50/model.py b/resnet50/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4309fcaf6290131df4b35e1deaf42760c49fe90b --- /dev/null +++ b/resnet50/model.py @@ -0,0 +1,164 @@ +import torch +import torch.nn.functional as F +from torch import nn +from .backbone import build_backbone +import pdb +import numpy as np +from typing import Optional + +class 
TokenOCR(nn.Module):
+    def __init__(self, backbone):
+        """Initializes the model.
+
+        Parameters:
+            backbone: torch module of the backbone to be used. See backbone.py
+        """
+        super().__init__()
+        self.language_embedding = nn.Embedding(92553, 2048, padding_idx=2)
+        for p in self.parameters():
+            p.requires_grad = False
+
+        self.backbone = backbone
+        init_tau = np.log(10)
+        init_b = -2.71
+        # self.t_prime = nn.Parameter(torch.ones([]) * init_tau)
+        # self.b = nn.Parameter(torch.ones([]) * init_b)
+        self.kb = True
+        self.upsample = nn.Sequential(
+            nn.ConvTranspose2d(
+                in_channels=2048,
+                out_channels=512,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False
+            ),
+            nn.SyncBatchNorm(512),
+            nn.ConvTranspose2d(
+                in_channels=512,
+                out_channels=512,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False
+            ),
+            nn.SyncBatchNorm(512),
+        )
+        self.ocr_mlp = nn.Sequential(
+            nn.Linear(512, 2048),
+            nn.GELU(),
+            nn.Linear(2048, 2048)
+        )
+
+    def forward(self,
+                pixel_values: torch.FloatTensor,
+                input_ids: torch.LongTensor = None,
+                image_flags: Optional[torch.LongTensor] = None,
+                mask_values: Optional[torch.LongTensor] = None,
+                masks_flags: Optional[torch.LongTensor] = None,
+                mask_nums: Optional[torch.LongTensor] = None,
+                ):
+        image_flags = image_flags.squeeze(-1)
+        try:
+            input_embeds = self.language_embedding(input_ids).clone()
+        except Exception:
+            print('language_embedding lookup failed; check that input_ids fit the 92553-token vocabulary')
+            raise
+        vit_embeds, vit_embeds_shape = self.extract_feature_custom(pixel_values)  # (vit_batch_size, 16*16, 2048)
+        nb, nl, nd = vit_embeds.shape
+        h, w = vit_embeds_shape
+        vit_embeds = vit_embeds.reshape(nb, h, w, nd)
+        vit_embeds = vit_embeds.split(list(image_flags))  # [(vit_batch_size / B, h, w, C)] * B
+        vit_batch_size = pixel_values.shape[0]
+
+        B, N, C = input_embeds.shape
+        try:
+            assert sum(image_flags) == mask_values.shape[0]
+        except AssertionError:
+            print('mask/image count mismatch:', (mask_values.shape, image_flags, mask_nums))
+
+        mask_values = torch.nn.functional.interpolate(mask_values.float(), size=(h, w), mode='bilinear', align_corners=False)  # (128, 128)
+        masks = mask_values.split(list(image_flags))  # [(vit_batch_size / B, N, 448, 448)] * B
+
+        masks_flags = masks_flags.chunk(B)
+        token_features = []
+        input_embedings = []
+        masked_input_ids = []
+        masked_zero_bools = []
+        for i, vit_embed in enumerate(vit_embeds):
+            current_token = masks_flags[i].sum()
+            mask = masks[i]
+            limit_num = mask.shape[1]
+            mask = mask.permute(1, 0, 2, 3).reshape(limit_num, -1) > 0
+            max_cluster_index = mask.sum(-1)
+            zero_bool = max_cluster_index != 0
+            mask[~zero_bool] = 1  # avoid all-zero rows (bfloat16 division workaround)
+            new_max_cluster_index = mask.sum(-1)
+            mask = mask / new_max_cluster_index.unsqueeze(-1)
+            token_feature = torch.matmul(mask.to(vit_embed), vit_embed.reshape(-1, vit_embed.shape[-1]))
+            token_features.extend(token_feature)
+            input_embedings.extend(input_embeds[i, :])
+            masked_input_ids.extend(input_ids[i, zero_bool])
+            masked_zero_bools.append(zero_bool)
+
+        masked_zero_bools = torch.cat(masked_zero_bools)
+        token_features = torch.stack(token_features)
+        input_embedings = torch.stack(input_embedings)
+
+        loss2 = F.mse_loss(token_features, input_embedings, reduction='none')[masked_zero_bools].sum(1).sqrt().mean()
+        token_features = token_features / token_features.norm(dim=1, keepdim=True)
+        input_embedings = input_embedings / input_embedings.norm(dim=1, keepdim=True)
+        # cosine similarity as
logits + similarity = F.cosine_similarity(token_features, input_embedings, dim=1) + loss1 = (1 - similarity[masked_zero_bools]).mean() + # loss_d = loss1 + loss2 + # if rank == 0: + # print(f'loss1:{loss_d}') + + ###siglip + # masked_input_ids = torch.stack(masked_input_ids) + # label_matrix = (masked_input_ids.unsqueeze(0) == masked_input_ids.unsqueeze(1)).int() + # label_matrix = 2 * label_matrix - 1 + # if self.kb: + # logits = (input_embedings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embedings.device).exp() + self.b.to(input_embedings.device) + # else: + # logits = (input_embedings[masked_zero_bools] @ token_features[masked_zero_bools].t()) * self.t_prime.to(input_embedings.device).exp() - 8.9375 + # loss_s = -torch.sum(F.logsigmoid(label_matrix * logits)) / logits.shape[0] + # if rank == 0: + # print(f'loss2:{loss_s}') + return loss1, loss2 + + def forward_tokenocr(self, pixel_values): + vit_embeds = self.backbone(pixel_values) + vit_embeds = vit_embeds['0'] + h, w = vit_embeds.shape[2], vit_embeds.shape[3] + vit_embeds = self.upsample(vit_embeds) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-2] * vit_embeds.shape[-1]) + vit_embeds = self.ocr_mlp(vit_embeds.permute(0, 2, 1)) + return vit_embeds, (h*4, w*4) + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build(args): + backbone = build_backbone(args) + model = TokenOCR(backbone) + return model diff --git a/tokenizer_path/special_tokens_map.json b/tokenizer_path/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..14c079b57d740ab21ceced78982ee102ff4fea48 --- /dev/null +++ b/tokenizer_path/special_tokens_map.json @@ -0,0 +1,48 @@ +{ + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|action_start|>", + "<|action_end|>", + "<|interpreter|>", + "<|plugin|>", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } + } + \ No newline at end of file diff --git a/tokenizer_path/tokenization_internlm2.py b/tokenizer_path/tokenization_internlm2.py new file mode 100644 index 0000000000000000000000000000000000000000..1be581da37ef678de65f2737493fc0ed7160446e --- /dev/null +++ b/tokenizer_path/tokenization_internlm2.py @@ -0,0 +1,235 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer +class InternLM2Tokenizer(PreTrainedTokenizer): + """ + Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ['input_ids', 'attention_mask'] + _auto_class = 'AutoTokenizer' + + def __init__( + self, + vocab_file, + unk_token='', + bos_token='', + eos_token='', + pad_token='', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.decode_with_prefix_space = decode_with_prefix_space + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + self._no_prefix_space_tokens = None + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def no_prefix_space_tokens(self): + if self._no_prefix_space_tokens is None: + vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) + self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')} + return self._no_prefix_space_tokens + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def _maybe_add_prefix_space(self, tokens, decoded): + if tokens and tokens[0] not in self.no_prefix_space_tokens: + return ' ' + decoded + else: + return decoded + + def convert_tokens_to_string(self, 
tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = '' + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += ' ' + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + out_string = self.clean_up_tokenization(out_string) + out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) + return out_string[1:] + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, 'wb') as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. 
+ token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/tokenizer_path/tokenization_internlm2_fast.py b/tokenizer_path/tokenization_internlm2_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..aa0fccbd0f1d029d79e19821f2edcb01b594537c --- /dev/null +++ b/tokenizer_path/tokenization_internlm2_fast.py @@ -0,0 +1,211 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization Fast class for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple + +from tokenizers import Tokenizer, decoders, normalizers, processors +from tokenizers.models import BPE +from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS, + SentencePieceExtractor, + SpmConverter) +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import logging + +from .tokenization_internlm2 import InternLM2Tokenizer + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} + + +# Modified from transformers.convert_slow_tokenizer.LlamaConverter +class InternLM2Converter(SpmConverter): + handle_byte_fallback = True + + def vocab(self, proto): + vocab = [ + ('', 0.0), + ('', 0.0), + ('', 0.0), + ] + vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] + return vocab + + def unk_id(self, proto): + unk_id = 0 + return unk_id + + def decoder(self, replacement, add_prefix_space): + return decoders.Sequence( + [ + decoders.Replace('▁', ' '), + decoders.ByteFallback(), + decoders.Fuse(), + decoders.Strip(content=' ', left=1), + ] + ) + + def tokenizer(self, proto): + model_type = proto.trainer_spec.model_type + vocab_scores = self.vocab(proto) + # special tokens + added_tokens = self.original_tokenizer.added_tokens_decoder + for i in range(len(vocab_scores)): + piece, score = vocab_scores[i] + if i in added_tokens: + vocab_scores[i] = (added_tokens[i].content, score) + if model_type == 1: + raise RuntimeError('InternLM2 is supposed to be a BPE model!') + + elif model_type == 2: + _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) + bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} + tokenizer = Tokenizer( + BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True) + ) + tokenizer.add_special_tokens( + [ added_token for index, added_token in added_tokens.items()] + ) + else: + raise Exception( + "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" + ) + + return tokenizer + + def 
normalizer(self, proto): + normalizers_list = [] + if proto.normalizer_spec.add_dummy_prefix: + normalizers_list.append(normalizers.Prepend(prepend='▁')) + normalizers_list.append(normalizers.Replace(pattern=' ', content='▁')) + return normalizers.Sequence(normalizers_list) + + def pre_tokenizer(self, replacement, add_prefix_space): + return None + + +SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter + + +# Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast +class InternLM2TokenizerFast(PreTrainedTokenizerFast): + vocab_files_names = VOCAB_FILES_NAMES + slow_tokenizer_class = InternLM2Tokenizer + padding_side = 'left' + model_input_names = ['input_ids', 'attention_mask'] + _auto_class = 'AutoTokenizer' + + def __init__( + self, + vocab_file, + unk_token='', + bos_token='', + eos_token='', + pad_token='', + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + sp_model_kwargs=sp_model_kwargs, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + decode_with_prefix_space=decode_with_prefix_space, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + if bos is None and self.add_bos_token: + raise ValueError('add_bos_token = True but bos_token = None') + + eos = self.eos_token + eos_token_id = self.eos_token_id + if eos is None and self.add_eos_token: + raise ValueError('add_eos_token = True but eos_token = None') + + single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' + 'tokenizer.' 
+ ) + + if not os.path.isdir(save_directory): + logger.error(f'Vocabulary path ({save_directory}) should be a directory') + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) diff --git a/tokenizer_path/tokenizer.model b/tokenizer_path/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..6600712949ca9c4ffb50f25275993a21fba0b408 --- /dev/null +++ b/tokenizer_path/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f868398fc4e05ee1e8aeba95ddf18ddcc45b8bce55d5093bead5bbf80429b48b +size 1477754 diff --git a/tokenizer_path/tokenizer_config.json b/tokenizer_path/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..d102412748394fbc89824d82a399710d3468e01f --- /dev/null +++ b/tokenizer_path/tokenizer_config.json @@ -0,0 +1,180 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92538": { + "content": "<|plugin|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92539": { + "content": "<|interpreter|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92540": { + "content": "<|action_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92541": { + "content": "<|action_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92542": { + "content": "<|im_end|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92543": { + "content": "<|im_start|>", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92544": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92545": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92546": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92547": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92548": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92549": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92550": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92551": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "92552": { + "content": "", 
+ "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "<|im_start|>", + "<|im_end|>", + "<|action_start|>", + "<|action_end|>", + "<|interpreter|>", + "<|plugin|>", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "auto_map": { + "AutoTokenizer": [ + "tokenization_internlm2.InternLM2Tokenizer", + null + ] + }, + "bos_token": "", + "chat_template": "{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "model_max_length": 16384, + "pad_token": "", + "tokenizer_class": "InternLM2Tokenizer", + "unk_token": "" + } + \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..074c42d61df8da4ffc6febcf68fd3fa5940b6f34 --- /dev/null +++ b/utils.py @@ -0,0 +1,194 @@ +import os +import torch +import torch.nn as nn +from internvl.train.dataset import build_transform, dynamic_preprocess +from internvl.model.internvl_chat import InternVisionModel, InternVLChatModel +from torchvision.utils import make_grid +import torchvision.transforms as T +import matplotlib.pyplot as plt +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModel, CLIPImageProcessor +import cv2 +from PIL import Image + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) +CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) +CLIP_STD = (0.2686295, 0.2613025, 0.2757711) +SIGLIP_MEAN = (0.5, 0.5, 0.5) +SIGLIP_STD = (0.5, 0.5, 0.5) +IMG_CONTEXT_TOKEN = '' +IMG_START_TOKEN = '' +IMG_END_TOKEN = '' +QUAD_START_TOKEN = '' +QUAD_END_TOKEN = '' +REF_START_TOKEN = '' +REF_END_TOKEN = '' +BOX_START_TOKEN = '' +BOX_END_TOKEN = '' + +def load_model(config, state_dict): + vision_model = InternVisionModel(config.vision_config) + vit = InternVLChatModel(config, vision_model).to(torch.bfloat16) + vit.load_state_dict(state_dict, strict=False) + tok_embeddings = nn.Embedding(config.llm_config.vocab_size, config.llm_config.hidden_size, 2).to(torch.bfloat16) + tok_embeddings.weight = nn.Parameter(state_dict['language_model.model.tok_embeddings.weight']) + return vit, tok_embeddings + +def load_image(image_path): + transform = get_transform(is_train=False, image_size=448) + image = Image.open(image_path).convert('RGB') + images, target_aspect_ratio = dynamic_preprocess(image, min_num=1, max_num=12, + image_size=448, use_thumbnail=True, return_ratio=True) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values).to(torch.bfloat16) + return pixel_values, images, target_aspect_ratio + +def get_similarity_map(sm, shape, min_max=True, threshold=0.2): + B, N, H, W = sm.shape + sm = sm.reshape(B, N, H*W) + if min_max: + # min-max norm + sm = (sm - sm.min(2, keepdim=True)[0]) / (sm.max(2, keepdim=True)[0] - sm.min(2, keepdim=True)[0]) + else: + sm = sm > threshold + sm = sm.float() + # reshape + sm = sm.reshape(B, N, H, W).float() + # interpolate + sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear') + return sm + +def build_transform_R50(normalize_type='imagenet'): + if normalize_type == 'imagenet': + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + elif normalize_type == 'clip': + MEAN, STD = CLIP_MEAN, CLIP_STD + elif normalize_type == 'siglip': + MEAN, STD = SIGLIP_MEAN, 
SIGLIP_STD
+    else:
+        raise NotImplementedError
+    transform = T.Compose([
+        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+        T.ToTensor(),
+        T.Normalize(mean=MEAN, std=STD)
+    ])
+    return transform
+
+def load_tokenizer(tokenizer_path):
+    tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False)
+    tokenizer.tokenizer_path = tokenizer_path
+    tokenizer.model_max_length = 8192
+    token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
+                  QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
+                  REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
+    num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
+    return tokenizer
+
+
+def get_transform(is_train, image_size):
+    # Build transformation function
+    transform = build_transform(is_train=is_train, input_size=image_size,
+                                pad2square=False, normalize_type='imagenet')
+    return transform
+
+def post_process(vit_embeds, target_aspect_ratio, model_type='VIT'):
+    if model_type in ["TokenOCR-4096-English-seg", "TokenOCR-2048-Bilingual-seg"]:
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        c = vit_embeds.shape[-1]
+        # When several tiles are present, the last (thumbnail) tile is dropped before stitching.
+        if vit_embeds.shape[0] == 1:
+            vit_embeds_local = vit_embeds.reshape(-1, h, w, c).permute(0, 3, 1, 2)
+        else:
+            vit_embeds_local = vit_embeds[:-1].reshape(-1, h, w, c).permute(0, 3, 1, 2)
+        vit_embeds_local = make_grid(vit_embeds_local, nrow=target_aspect_ratio[0], padding=0, normalize=False)
+        vit_embeds_local = vit_embeds_local.permute(1, 2, 0)
+        H, W, C = vit_embeds_local.shape
+        vit_embeds_local = vit_embeds_local.reshape(H * W, C)
+        return vit_embeds_local, (H, W)
+    if 'R50' in model_type:
+        vit_embeds = vit_embeds.reshape(-1, vit_embeds.shape[-1])
+        return vit_embeds, None
+
+def generate_similiarity_map(images, attn_map, all_bpe_strings, vis_list, target_aspect_ratio=(1,1), src_iamge_size=(1014, 1024), image_size=448):
+    # NOTE: an earlier (disabled) variant stitched a list of image tiles into a grid with
+    # make_grid and derived the target size from image_size * target_aspect_ratio; the
+    # current code visualises only the first image and uses its own size as the target.
+    images = images[0]
+    images_vis = T.ToTensor()(images)
+    target_height = images.size[1]
+    target_width = images.size[0]
+
+    print("attn_map", attn_map.shape)  # e.g. torch.Size([4, 76, 128])
+    print("target_height", target_height)  # problematic: 608
+    print("target_width", target_width)  # problematic: 1024
+
+    attn_norm = get_similarity_map(attn_map.unsqueeze(0), (target_height, target_width), min_max=True, threshold=0.15)
+    print("attn_norm", attn_norm.shape)  # problematic: attn_norm was torch.Size([1, 4, 448, 448])
+    print('all_bpe_strings:{:}'.format(all_bpe_strings))
+    indexes_without_space = torch.tensor([index for index, string in enumerate(all_bpe_strings) if string != ' '])
+
+    # Draw similarity map
+    images_vis = (images_vis.permute(1, 2, 0).cpu().numpy() * 125).astype('uint8')
+    for b in range(attn_norm.shape[0]):
+        for n in range(attn_norm.shape[1] - 1):
+            vis = (attn_norm[b, n, :, :].float().detach().cpu().numpy() * 255).astype('uint8')
+            vis = cv2.applyColorMap(vis, cv2.COLORMAP_JET)
+            vis = images_vis * 0.5 + vis * 0.5
+            vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB)
+            vis = cv2.resize(vis, src_iamge_size)
+            vis_list.append(vis)  # Add each visualization to the list
+
+        # Combined map: strongest non-space response minus the trailing-space response.
+        without_space_norm = attn_norm[b, indexes_without_space, :, :].max(0)[0]
+        space_norm = attn_norm[b, -1, :, :]
+        all_attn_norm = without_space_norm - space_norm
+        print(f'min:{all_attn_norm.min()};max:{all_attn_norm.max()}')
+        all_attn_norm = (all_attn_norm - all_attn_norm.min()) / (all_attn_norm.max() - all_attn_norm.min())
+        all_attn_norm = (all_attn_norm.float().detach().cpu().numpy() * 255).astype('uint8')
+        vis = cv2.applyColorMap(all_attn_norm, cv2.COLORMAP_JET)
+        vis = images_vis * 0.5 + vis * 0.5
+        vis = cv2.cvtColor(vis.astype('uint8'), cv2.COLOR_BGR2RGB)
+        vis = cv2.resize(vis, src_iamge_size)
+        vis_list.append(vis)  # Add each visualization to the list
+
+    return vis_list
+
+
+def load_model_and_tokenizer_customed(checkpoint):
+    kwargs = {}
+    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True, use_fast=False)
+    model = InternVLChatModel.from_pretrained(
+        checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
+        load_in_8bit=False, load_in_4bit=False, **kwargs).eval()
+    # Only the embeddings and vision tower are needed here, so the LLM decoder layers and output head are dropped.
+    del model.language_model.model.layers
+    del model.language_model.output
+    model = model.cuda()
+    return model, tokenizer
\ No newline at end of file
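For completeness, a minimal usage sketch of the similarity-map utilities defined above, driven by dummy tensors instead of real model features; the image size, token strings, and attention-map shape are made up for illustration, and it assumes the repository's internvl package is importable (utils.py imports it at module level):

    import torch
    from PIL import Image
    from utils import generate_similiarity_map

    # One placeholder RGB image and a random (num_tokens, H_feat, W_feat) attention map.
    image = Image.new('RGB', (1024, 608))
    attn_map = torch.rand(4, 76, 128)
    all_bpe_strings = ['He', 'llo', '!', ' ']  # trailing entry is the appended space token

    vis_list = generate_similiarity_map([image], attn_map, all_bpe_strings, vis_list=[],
                                        target_aspect_ratio=(1, 1), src_iamge_size=(1024, 608))
    # One JET-colormap overlay per non-space token plus one combined map.
    print(len(vis_list), vis_list[0].shape)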