Commit ec9267c (verified) by zcamz
Parent: a4dc568

Upload TinyLlavaForConditionalGeneration
config.json ADDED
@@ -0,0 +1,145 @@
+ {
+   "architectures": [
+     "TinyLlavaForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration.TinyLlavaConfig",
+     "AutoModelForCausalLM": "modeling_tinyllava_elm.TinyLlavaForConditionalGeneration"
+   },
+   "cache_dir": null,
+   "connector_type": "mlp2x_gelu",
+   "hidden_size": 1280,
+   "ignore_index": -100,
+   "image_aspect_ratio": "square",
+   "image_token_index": -200,
+   "llm_model_name_or_path": "apple/OpenELM-270M-Instruct",
+   "model_type": "tinyllava",
+   "num_queries": 128,
+   "num_resampler_layers": 3,
+   "pad_token": "<unk>",
+   "resampler_hidden_size": 768,
+   "text_config": {
+     "_attn_implementation_autoset": true,
+     "_name_or_path": "apple/OpenELM-270M-Instruct",
+     "activation_fn_name": "swish",
+     "architectures": [
+       "OpenELMForCausalLM"
+     ],
+     "auto_map": {
+       "AutoConfig": "apple/OpenELM-270M-Instruct--configuration_openelm.OpenELMConfig",
+       "AutoModelForCausalLM": "apple/OpenELM-270M-Instruct--modeling_openelm.OpenELMForCausalLM"
+     },
+     "ffn_dim_divisor": 256,
+     "ffn_multipliers": [
+       0.5,
+       0.73,
+       0.97,
+       1.2,
+       1.43,
+       1.67,
+       1.9,
+       2.13,
+       2.37,
+       2.6,
+       2.83,
+       3.07,
+       3.3,
+       3.53,
+       3.77,
+       4.0
+     ],
+     "ffn_with_glu": true,
+     "head_dim": 64,
+     "max_context_length": 2048,
+     "model_dim": 1280,
+     "model_type": "openelm",
+     "normalization_layer_name": "rms_norm",
+     "normalize_qk_projections": true,
+     "num_gqa_groups": 4,
+     "num_kv_heads": [
+       3,
+       3,
+       3,
+       3,
+       3,
+       4,
+       4,
+       4,
+       4,
+       4,
+       4,
+       4,
+       5,
+       5,
+       5,
+       5
+     ],
+     "num_query_heads": [
+       12,
+       12,
+       12,
+       12,
+       12,
+       16,
+       16,
+       16,
+       16,
+       16,
+       16,
+       16,
+       20,
+       20,
+       20,
+       20
+     ],
+     "num_transformer_layers": 16,
+     "qkv_multipliers": [
+       0.5,
+       1.0
+     ],
+     "rope_freq_constant": 10000,
+     "rope_max_length": 4096,
+     "share_input_output_layers": true,
+     "tie_word_embeddings": true,
+     "torch_dtype": "float16"
+   },
+   "tokenizer_model_max_length": 2048,
+   "tokenizer_name_or_path": "meta-llama/Llama-2-7b-hf",
+   "tokenizer_padding_side": "right",
+   "tokenizer_use_fast": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.47.1",
+   "tune_type_connector": "full",
+   "tune_type_llm": "lora",
+   "tune_type_vision_tower": "frozen",
+   "tune_vision_tower_from_layer": 0,
+   "use_cache": true,
+   "vision_config": {
+     "_name_or_path": "apple/aimv2-large-patch14-224-distilled",
+     "architectures": [
+       "AIMv2Model"
+     ],
+     "auto_map": {
+       "AutoConfig": "apple/aimv2-large-patch14-224-distilled--configuration_aimv2.AIMv2Config",
+       "AutoModel": "apple/aimv2-large-patch14-224-distilled--modeling_aimv2.AIMv2Model",
+       "FlaxAutoModel": "apple/aimv2-large-patch14-224-distilled--modeling_flax_aimv2.FlaxAIMv2Model"
+     },
+     "image_size": 224,
+     "intermediate_size": 2816,
+     "model_name_or_path": "apple/aimv2-large-patch14-224-distilled",
+     "model_name_or_path2": "",
+     "model_type": "aimv2",
+     "num_attention_heads": 8,
+     "projection_dropout": 0.0,
+     "qkv_bias": false,
+     "rms_norm_eps": 1e-05,
+     "torch_dtype": "float32",
+     "use_bias": false
+   },
+   "vision_feature_layer": -2,
+   "vision_feature_select_strategy": "patch",
+   "vision_hidden_size": 1024,
+   "vision_model_name_or_path": "apple/aimv2-large-patch14-224-distilled",
+   "vision_model_name_or_path2": "",
+   "vocab_size": 32000
+ }
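
The auto_map entries above route AutoConfig and AutoModelForCausalLM to the bundled configuration.py and modeling_tinyllava_elm.py, so the checkpoint is meant to be loaded with trust_remote_code=True. A minimal loading sketch; the repository id is a placeholder, not something stated in this commit:

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<this-model-repo>"  # placeholder for the repository this commit belongs to
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected: TinyLlavaForConditionalGeneration
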
configuration.py ADDED
@@ -0,0 +1,136 @@
+ from transformers import PretrainedConfig, LlavaConfig
+ from transformers import CONFIG_MAPPING
+ from transformers import AutoConfig
+
+ IGNORE_INDEX = -100
+ IMAGE_TOKEN_INDEX = -200
+ DEFAULT_IMAGE_TOKEN = "<image>"
+
+ class TinyLlavaConfig(PretrainedConfig):
+
+     model_type = "tinyllava"
+     def __init__(
+         self,
+         llm_model_name_or_path = '',
+         tokenizer_name_or_path = None,
+         vision_model_name_or_path = '',
+         vision_model_name_or_path2 = '',
+         connector_type = None,
+         text_config=None,
+         hidden_size=2048,
+         vocab_size=32000,
+         ignore_index=-100,
+         image_token_index=32000,
+         pad_token = None,
+         pad_token_id = None,
+         tokenizer_padding_side = 'right',
+         tokenizer_model_max_length = 2048,
+         vision_config = None,
+         vision_hidden_size = None,
+         vision_feature_layer = -2,
+         vision_feature_select_strategy = 'patch',
+         image_aspect_ratio = 'square',
+         resampler_hidden_size = None,
+         num_queries = None,
+         num_resampler_layers = None,
+         use_cache = False,
+         cache_dir = None,
+         tokenizer_use_fast = False,
+         tune_type_llm = 'frozen',
+         tune_type_connector = 'frozen',
+         tune_type_vision_tower = 'frozen',
+         tune_vision_tower_from_layer = -1,
+
+         **kwargs
+
+     ):
+         self.llm_model_name_or_path = llm_model_name_or_path
+         self.tokenizer_name_or_path = tokenizer_name_or_path or self.llm_model_name_or_path
+         self.vision_model_name_or_path = vision_model_name_or_path
+         self.vision_model_name_or_path2 = vision_model_name_or_path2
+         self.connector_type = connector_type
+         self.tune_type_llm = tune_type_llm
+         self.tune_type_connector = tune_type_connector
+         self.tune_type_vision_tower = tune_type_vision_tower
+         self.tune_vision_tower_from_layer = tune_vision_tower_from_layer
+
+         self.ignore_index = IGNORE_INDEX
+         self.image_token_index = IMAGE_TOKEN_INDEX
+         self.pad_token = pad_token
+         self.pad_token_id = pad_token_id
+         self.tokenizer_padding_side = tokenizer_padding_side
+         self.tokenizer_model_max_length = tokenizer_model_max_length
+         self.vision_feature_layer = vision_feature_layer
+         self.vision_feature_select_strategy = vision_feature_select_strategy
+         self.image_aspect_ratio = image_aspect_ratio
+         self.resampler_hidden_size = resampler_hidden_size
+         self.num_queries = num_queries
+         self.num_resampler_layers = num_resampler_layers
+         self.use_cache = use_cache
+         self.cache_dir = cache_dir
+         self.tokenizer_use_fast = tokenizer_use_fast
+         self._load_text_config(text_config)
+         self._load_vision_config(vision_config)
+
+         super().__init__(**kwargs)
+
+     def load_from_config(self, config):
+         self.llm_model_name_or_path = getattr(config, 'model_name_or_path', '')
+         self.tokenizer_name_or_path = getattr(config, 'tokenizer_name_or_path', None) or self.llm_model_name_or_path
+         self.vision_model_name_or_path = getattr(config, 'vision_tower', '')
+         self.vision_model_name_or_path2 = getattr(config, 'vision_tower2', '')
+         self.connector_type = getattr(config, 'connector_type', None)
+         self.vision_feature_layer = getattr(config, 'mm_vision_select_layer', -2)
+         self.vision_feature_select_strategy = getattr(config, 'mm_vision_select_feature', "patch")
+         self.image_aspect_ratio = getattr(config, 'image_aspect_ratio', "pad")
+         self.resampler_hidden_size = getattr(config, 'resampler_hidden_size', None)
+         self.num_queries = getattr(config, 'num_queries', None)
+         self.num_resampler_layers = getattr(config, 'num_resampler_layers', None)
+
+         self.cache_dir = getattr(config, 'cache_dir', None)
+         self.tokenizer_use_fast = getattr(config, 'tokenizer_use_fast', False)
+         self.tokenizer_model_max_length = getattr(config, 'model_max_length', 2048)
+         self.tokenizer_padding_side = getattr(config, 'tokenizer_padding_side', 'right')
+
+         self._load_text_config()
+         self._load_vision_config()
+
+
+     def _load_text_config(self, text_config=None):
+         if self.llm_model_name_or_path is None or self.llm_model_name_or_path == '':
+             self.text_config = CONFIG_MAPPING['llama']()
+
+         else:
+             self.text_config = AutoConfig.from_pretrained(self.llm_model_name_or_path, trust_remote_code=True)
+             if text_config is not None:
+                 self.text_config = self.text_config.from_dict(text_config)
+
+         self.hidden_size = getattr(self.text_config, 'hidden_size', getattr(self.text_config, 'model_dim', None))
+         self.vocab_size = getattr(self.text_config, 'vocab_size', None)
+
+
+
+     def _load_vision_config(self, vision_config=None):
+         if self.vision_model_name_or_path is None or self.vision_model_name_or_path == '':
+             self.vision_config = CONFIG_MAPPING['clip_vision_model'](
+                 intermediate_size=4096,
+                 hidden_size=1024,
+                 patch_size=14,
+                 image_size=336,
+                 num_hidden_layers=24,
+                 num_attention_heads=16,
+                 vocab_size=32000,
+                 projection_dim=768,
+             )
+
+         else:
+             self.vision_config = AutoConfig.from_pretrained(self.vision_model_name_or_path.split(':')[-1], trust_remote_code=True)
+             self.vision_config = getattr(self.vision_config, 'vision_config', self.vision_config)
+             if vision_config is not None:
+                 self.vision_config = self.vision_config.from_dict(vision_config)
+
+         self.vision_config.model_name_or_path = self.vision_model_name_or_path.split(':')[-1]
+         self.vision_config.model_name_or_path2 = self.vision_model_name_or_path2.split(':')[-1]
+         self.vision_hidden_size = getattr(self.vision_config, 'hidden_size', None)
+
+
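
A small usage sketch of the TinyLlavaConfig class above: with no arguments, _load_text_config falls back to CONFIG_MAPPING['llama']() and _load_vision_config to the CLIP vision defaults, and the module-level constants override the ignore_index/image_token_index arguments. The expected values are read off the code, not measured output:

from configuration import TinyLlavaConfig

cfg = TinyLlavaConfig()                 # empty model paths -> fallback sub-configs
print(cfg.text_config.model_type)       # "llama"
print(cfg.vision_config.model_type)     # "clip_vision_model"
print(cfg.ignore_index, cfg.image_token_index)  # -100 -200
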
generation_config.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "transformers_version": "4.47.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0469b8698895fb3758ab656d55b5e4d1fcb3d6cc97dc255b56a08cdd6e65a09d
+ size 2334746616
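
The three lines above are a Git LFS pointer rather than the weights themselves; after downloading the actual model.safetensors (about 2.3 GB), its SHA-256 and byte size should match the oid and size fields. A hedged verification sketch:

import hashlib, os

def matches_lfs_pointer(path, expected_oid, expected_size):
    # Stream in 1 MiB chunks so the ~2.3 GB file never has to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest() == expected_oid and os.path.getsize(path) == expected_size

print(matches_lfs_pointer(
    "model.safetensors",
    "0469b8698895fb3758ab656d55b5e4d1fcb3d6cc97dc255b56a08cdd6e65a09d",
    2334746616,
))
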
modeling_tinyllava_elm.py ADDED
@@ -0,0 +1,1911 @@
1
+ from dataclasses import dataclass
2
+ import dataclasses
3
+ from typing import List, Optional, Tuple, Union
4
+ import ast
5
+ import re
6
+ from enum import auto, Enum
7
+ import requests
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ import base64
11
+ import time
12
+
13
+ import torch
14
+ import torch.utils.checkpoint
15
+ from torch import nn, Tensor
16
+ from torch.nn import functional as F
17
+
18
+ from transformers import PreTrainedModel
19
+ from transformers.modeling_outputs import CausalLMOutputWithPast
20
+ from transformers.generation.utils import GenerateOutput
21
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, AutoImageProcessor
22
+
23
+ from configuration import TinyLlavaConfig, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
24
+
25
+ # from tinyllava.utils.data_utils import get_value_from_kwargs
26
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
27
+ WORKER_HEART_BEAT_INTERVAL = 15
28
+
29
+ LOGDIR = "."
30
+ import os
31
+ #
32
+ # For licensing see accompanying LICENSE file.
33
+ # Copyright (C) 2024 Apple Inc. All Rights Reserved.
34
+ #
35
+
36
+ from torch.nn import CrossEntropyLoss
37
+ from transformers.activations import ACT2FN
38
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
39
+ from transformers.modeling_outputs import (
40
+ BaseModelOutputWithPast,
41
+ )
42
+ from transformers.utils import logging
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ # this import has to be relative, otherwise, when setting trust_remote_code=True
47
+ # huggingface transformers won't be able to load the module correctly
48
+ from numbers import Number
49
+ from typing import List, Optional, Union
50
+
51
+ import numpy as np
52
+ from transformers import PretrainedConfig, AutoTokenizer
53
+
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ # Model Constants
59
+ IGNORE_INDEX = -100
60
+ IMAGE_TOKEN_INDEX = -200
61
+ DEFAULT_IMAGE_TOKEN = "<image>"
62
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
63
+ DEFAULT_IM_START_TOKEN = "<im_start>"
64
+ DEFAULT_IM_END_TOKEN = "<im_end>"
65
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
66
+
67
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
68
+ WORKER_HEART_BEAT_INTERVAL = 15
69
+ LOGDIR = "."
70
+
71
+
72
+ class SeparatorStyle(Enum):
73
+ """Different separator style."""
74
+ SINGLE = auto()
75
+ TWO = auto()
76
+ MPT = auto()
77
+ PLAIN = auto()
78
+ LLAMA_2 = auto()
79
+ TINY_LLAMA = auto()
80
+ QWEN_2 = auto()
81
+
82
+
83
+ @dataclasses.dataclass
84
+ class Conversation:
85
+ """A class that keeps all conversation history."""
86
+ system: str
87
+ roles: List[str]
88
+ messages: List[List[str]]
89
+ offset: int
90
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
91
+ sep: str = "###"
92
+ sep2: str = None
93
+ version: str = "Unknown"
94
+
95
+ skip_next: bool = False
96
+
97
+ def get_prompt(self):
98
+ messages = self.messages
99
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
100
+ messages = self.messages.copy()
101
+ init_role, init_msg = messages[0].copy()
102
+ init_msg = init_msg[0].replace("<image>", "").strip()
103
+ if 'mmtag' in self.version:
104
+ messages[0] = (init_role, init_msg)
105
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
106
+ messages.insert(1, (self.roles[1], "Received."))
107
+ else:
108
+ messages[0] = (init_role, "<image>\n" + init_msg)
109
+
110
+ if self.sep_style == SeparatorStyle.TWO:
111
+ seps = [self.sep, self.sep2]
112
+ ret = self.system + seps[0]
113
+ for i, (role, message) in enumerate(messages):
114
+ if message:
115
+ if type(message) is tuple:
116
+ message, _, _ = message
117
+ ret += role + ": " + message + seps[i % 2]
118
+ else:
119
+ ret += role + ":"
120
+ else:
121
+ raise ValueError(f"Invalid style: {self.sep_style}")
122
+
123
+ return ret
124
+
125
+ def append_message(self, role, message):
126
+ self.messages.append([role, message])
127
+
128
+ def copy(self):
129
+ return Conversation(
130
+ system=self.system,
131
+ roles=self.roles,
132
+ messages=[[x, y] for x, y in self.messages],
133
+ offset=self.offset,
134
+ sep_style=self.sep_style,
135
+ sep=self.sep,
136
+ sep2=self.sep2,
137
+ version=self.version)
138
+
139
+
140
+
141
+
142
+ conv_phi_v0 = Conversation(
143
+ system="A chat between a curious user and an artificial intelligence assistant. "
144
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
145
+ roles=("USER", "ASSISTANT"),
146
+ version="phi",
147
+ messages=(),
148
+ offset=0,
149
+ sep_style=SeparatorStyle.TWO,
150
+ sep=" ",
151
+ sep2="<|endoftext|>",
152
+ )
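
A hedged sketch of how this template is used: copy conv_phi_v0, append a user turn and an empty assistant turn, and get_prompt joins them with the TWO separator style (sep after user messages, sep2 after assistant messages):

conv = conv_phi_v0.copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in the picture?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# -> "<system text> USER: <image>\nWhat is in the picture? ASSISTANT:"
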
153
+
154
+
155
+ def load_image_from_base64(image):
156
+ return Image.open(BytesIO(base64.b64decode(image)))
157
+
158
+
159
+ def expand2square(pil_img, background_color):
160
+ width, height = pil_img.size
161
+ if width == height:
162
+ return pil_img
163
+ elif width > height:
164
+ result = Image.new(pil_img.mode, (width, width), background_color)
165
+ result.paste(pil_img, (0, (width - height) // 2))
166
+ return result
167
+ else:
168
+ result = Image.new(pil_img.mode, (height, height), background_color)
169
+ result.paste(pil_img, ((height - width) // 2, 0))
170
+ return result
171
+
172
+
173
+ def process_images(images, image_processor, model_cfg):
174
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
175
+ new_images = []
176
+ if image_aspect_ratio == 'pad':
177
+ for image in images:
178
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
179
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
180
+ new_images.append(image)
181
+ else:
182
+ return image_processor(images, return_tensors='pt')['pixel_values']
183
+ if all(x.shape == new_images[0].shape for x in new_images):
184
+ new_images = torch.stack(new_images, dim=0)
185
+ return new_images
186
+
187
+
188
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
189
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
190
+
191
+ def insert_separator(X, sep):
192
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
193
+
194
+ input_ids = []
195
+ offset = 0
196
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
197
+ offset = 1
198
+ input_ids.append(prompt_chunks[0][0])
199
+
200
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
201
+ input_ids.extend(x[offset:])
202
+
203
+ if return_tensors is not None:
204
+ if return_tensors == 'pt':
205
+ return torch.tensor(input_ids, dtype=torch.long)
206
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
207
+ return input_ids
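
A hedged usage sketch for tokenizer_image_token: the prompt is split on the "<image>" placeholder, each chunk is tokenized separately, and IMAGE_TOKEN_INDEX (-200) is spliced in between the chunks. The Llama-2 tokenizer is assumed here because config.json sets tokenizer_name_or_path to meta-llama/Llama-2-7b-hf (a gated repository):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=False)
prompt = "USER: <image>\nDescribe the image. ASSISTANT:"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
# A single -200 entry marks where the image features will be spliced in later.
print((input_ids == IMAGE_TOKEN_INDEX).sum().item())  # 1
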
208
+
209
+ def load_image(image_file):
210
+ if image_file.startswith("http") or image_file.startswith("https"):
211
+ response = requests.get(image_file)
212
+ image = Image.open(BytesIO(response.content)).convert("RGB")
213
+ else:
214
+ image = Image.open(image_file).convert("RGB")
215
+ return image
216
+
217
+
218
+ def make_divisible(
219
+ v: Union[float, int],
220
+ divisor: Optional[int] = 8,
221
+ min_value: Optional[Union[float, int]] = None,
222
+ ) -> Union[float, int]:
223
+ """
224
+ This function is taken from the original tf repo.
225
+ It ensures that all layers have a channel number that is divisible by the divisor
226
+ It can be seen at:
227
+ https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
228
+ Args:
229
+ v: input value
230
+ divisor: default to 8
231
+ min_value: minimum divisor value
232
+ Returns:
233
+ new_v: new divisible value
234
+ """
235
+ if min_value is None:
236
+ min_value = divisor
237
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
238
+ # Make sure that round down does not go down by more than 10%.
239
+ if new_v < 0.9 * v:
240
+ new_v += divisor
241
+ return new_v
242
+
243
+
244
+ def compute_heads(model_dim: int, head_dim: int) -> int:
245
+ """Compute the number of heads.
246
+ Args:
247
+ model_dim: Model dimension.
248
+ head_dim: Head dimension.
249
+ Returns:
250
+ An integer denoting number of heads in multi-head attention is returned.
251
+ Raises:
252
+ ValueError: if model dimension is not divisible by head dimension.
253
+ """
254
+ if model_dim % head_dim == 0:
255
+ return model_dim // head_dim
256
+ else:
257
+ raise ValueError(
258
+ f"Model dimension should be divisible by head dimension. Got: {model_dim} and {head_dim}."
259
+ )
260
+
261
+
262
+ OpenELM_CONFIGS = {
263
+ "OpenELM-270M": dict(
264
+ num_transformer_layers=16,
265
+ model_dim=1280,
266
+ head_dim=64,
267
+ num_gqa_groups=4,
268
+ normalize_qk_projections=True,
269
+ share_input_output_layers=True,
270
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
271
+ ffn_multipliers=(0.5, 4.0),
272
+ qkv_multipliers=(0.5, 1.0),
273
+ ),
274
+ "OpenELM-450M": dict(
275
+ num_transformer_layers=20,
276
+ model_dim=1536,
277
+ head_dim=64,
278
+ num_gqa_groups=4,
279
+ normalize_qk_projections=True,
280
+ share_input_output_layers=True,
281
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
282
+ ffn_multipliers=(0.5, 4.0),
283
+ qkv_multipliers=(0.5, 1.0),
284
+ ),
285
+ "OpenELM-1_1B": dict(
286
+ num_transformer_layers=28,
287
+ model_dim=2048,
288
+ head_dim=64,
289
+ num_gqa_groups=4,
290
+ normalize_qk_projections=True,
291
+ share_input_output_layers=True,
292
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
293
+ ffn_multipliers=(0.5, 4.0),
294
+ qkv_multipliers=(0.5, 1.0),
295
+ ),
296
+ "OpenELM-3B": dict(
297
+ num_transformer_layers=36,
298
+ model_dim=3072,
299
+ head_dim=128,
300
+ num_gqa_groups=4,
301
+ normalize_qk_projections=True,
302
+ share_input_output_layers=True,
303
+ # Vary the FFN and QKV multipliers to create variable FFN and attention layers respectively.
304
+ ffn_multipliers=(0.5, 4.0),
305
+ qkv_multipliers=(0.5, 1.0),
306
+ ),
307
+ }
308
+
309
+
310
+ class OpenELMConfig(PretrainedConfig):
311
+ r"""
312
+ This is the configuration class to store the configuration of a [`OpenELMModel`]. It is used to instantiate an OpenELM model according to the specified arguments, defining the model architecture.
313
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
314
+ documentation from [`PretrainedConfig`] for more information.
315
+ Args:
316
+ vocab_size (`int`, *optional*, defaults to 32000):
317
+ Vocabulary size of the OpenELM model.
318
+ max_context_length (`int`, *optional*, defaults to 2048):
319
+ Maximum number of input tokens.
320
+ num_transformer_layers (`int`, *optional*, defaults to 12):
321
+ Number of hidden layers in the Transformer decoder.
322
+ model_dim (`int`, *optional*, defaults to 2048):
323
+ Dimension of the hidden representations.
324
+ head_dim (`int`, *optional*, defaults to 128):
325
+ The attention head dimension.
326
+ qkv_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 1.0):
327
+ If the qkv_multipliers is a Number, then all attention layers have the same latent dimensions,
328
+ resulting in uniform allocation of parameters.
329
+ If the qkv_multipliers is a List of Number, then each attention layer have different latent dimensions
330
+ assuming qkv_multipliers[0] != qkv_multipliers[1]. This results in variable allocation of parameters in attention layer.
331
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
332
+ num_query_heads (`Union[int, None]`, *optional*, defaults to None):
333
+ The number of query heads, computed from `compute_heads(model_dim=model_dim, head_dim=head_dim)`.
334
+ num_gqa_groups (`int`, *optional*, defaults to 1):
335
+ This variable allows to switch between multi-head attention, group query attention, and multi-query attention.
336
+ When num_gqa_groups == 1, then it is multi-head attention.
337
+ When 1 < num_gqa_groups < num_heads and num_heads is divisible by num_gqa_groups, then it is group query attention
338
+ When num_gqa_groups == num_heads, then it is multi-query attention
339
+ ffn_multipliers (`Union[Number, List[Number]]`, *optional*, defaults to 4.0):
340
+ Feed-forward network (FFN) multipliers.
341
+ If the ffn_multipliers is a Number, then all FFN layers have the same latent dimensions,
342
+ resulting in uniform allocation of parameters.
343
+ If the ffn_multipliers is a List of Number, then each FFN layer have different latent dimensions
344
+ assuming ffn_multipliers[0] != ffn_multipliers[1]. This results in variable allocation of parameters in FFN layer.
345
+ This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
346
+ ffn_with_glu (`bool`, *optional*, defaults to True):
347
+ Whether to use FFN with Gated Linear Unit (GLU)
348
+ ffn_dim_divisor (`int`, *optional*, defaults to 256):
349
+ The ffn layer dimension divisor.
350
+ activation_fn_name (`str` or `function`, *optional*, defaults to `"swish"`):
351
+ The non-linear activation function (function or string) in the decoder.
352
+ normalization_layer_name (`str` or `function`, *optional*, defaults to `"rms_norm"`):
353
+ Type of normalization layer.
354
+ normalize_qk_projections (`bool`, *optional*, defaults to False):
355
+ Whether to normalize queries and keys after projections
356
+ share_input_output_layers (`bool`, *optional*, defaults to False):
357
+ Whether to share the embedding between input and output linear layer
358
+ rope_freq_constant (`int`, *optional*, defaults to 10000):
359
+ The base period of the RoPE embeddings.
360
+ rope_max_length (`int`, *optional*, defaults to 4096):
361
+ That rope_max_length is set to twice of max_context_length.
362
+ This allows flexibility in token lengths during training or fine-tuning.
363
+ initializer_range (`float`, *optional*, defaults to 0.02):
364
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
365
+ use_cache (`bool`, *optional*, defaults to `True`):
366
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
367
+ relevant if `config.is_decoder=True`.
368
+ bos_token_id (`int`, *optional*, defaults to 2):
369
+ Beginning of stream token id.
370
+ eos_token_id (`int`, *optional*, defaults to 1):
371
+ End of stream token id.
372
+ """
373
+
374
+ model_type = "openelm"
375
+
376
+ def __init__(
377
+ self,
378
+ vocab_size: int = 32000,
379
+ max_context_length: int = 2048,
380
+ num_transformer_layers: int = 12,
381
+ model_dim: int = 2048,
382
+ head_dim: int = 128,
383
+ qkv_multipliers: Union[Number, List[Number]] = 1.0,
384
+ num_query_heads: Union[int, None] = None,
385
+ num_gqa_groups: int = 1,
386
+ ffn_multipliers: Union[Number, List[Number]] = 4.0,
387
+ ffn_with_glu: bool = True,
388
+ ffn_dim_divisor: int = 256,
389
+ activation_fn_name: str = "swish",
390
+ normalization_layer_name: str = "rms_norm",
391
+ normalize_qk_projections: bool = False,
392
+ share_input_output_layers: bool = False,
393
+ rope_freq_constant: int = 10000,
394
+ rope_max_length: int = 4096,
395
+ initializer_range: float = 0.02,
396
+ use_cache: bool = True,
397
+ bos_token_id: int = 1,
398
+ eos_token_id: int = 2,
399
+ **kwargs,
400
+ ) -> None:
401
+ self.vocab_size = vocab_size
402
+ self.max_context_length = max_context_length
403
+ self.num_transformer_layers = num_transformer_layers
404
+ self.model_dim = model_dim
405
+ self.head_dim = head_dim
406
+ self.qkv_multipliers = qkv_multipliers
407
+ self.num_query_heads = num_query_heads
408
+ self.num_gqa_groups = num_gqa_groups
409
+ self.ffn_multipliers = ffn_multipliers
410
+ self.ffn_with_glu = ffn_with_glu
411
+ self.ffn_dim_divisor = ffn_dim_divisor
412
+ self.activation_fn_name = activation_fn_name
413
+ self.normalization_layer_name = normalization_layer_name
414
+ self.normalize_qk_projections = normalize_qk_projections
415
+ self.share_input_output_layers = share_input_output_layers
416
+ self.rope_freq_constant = rope_freq_constant
417
+ self.rope_max_length = rope_max_length
418
+ self.num_query_heads = (
419
+ compute_heads(model_dim=model_dim, head_dim=head_dim)
420
+ if num_query_heads is None
421
+ else num_query_heads
422
+ )
423
+ self.initializer_range = initializer_range
424
+
425
+ self.__post_init__()
426
+ super().__init__(
427
+ use_cache=use_cache,
428
+ bos_token_id=bos_token_id,
429
+ eos_token_id=eos_token_id,
430
+ **kwargs,
431
+ )
432
+
433
+ def __post_init__(self) -> None:
434
+ if self.num_gqa_groups is not None:
435
+ head_multiple_of = self.num_gqa_groups
436
+ else:
437
+ head_multiple_of = 2
438
+
439
+ if isinstance(self.qkv_multipliers, Number):
440
+ # All attention layers have the same latent dimensions, resulting in uniform allocation of parameters.
441
+ qkv_dim = make_divisible(
442
+ self.model_dim * self.qkv_multipliers,
443
+ divisor=self.head_dim * head_multiple_of,
444
+ )
445
+ query_dims = [int(qkv_dim)] * self.num_transformer_layers
446
+
447
+ elif (
448
+ isinstance(self.qkv_multipliers, (tuple, list))
449
+ and len(self.qkv_multipliers) == 2
450
+ ):
451
+ # Each attention layer have different latent dimensions assuming qkv_multipliers[0] != qkv_multipliers[1].
452
+ # This results in variable allocation of parameters in attention layer.
453
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
454
+ qkv_multipliers = [
455
+ round(v, 2)
456
+ for v in np.linspace(
457
+ self.qkv_multipliers[0],
458
+ self.qkv_multipliers[1],
459
+ num=self.num_transformer_layers,
460
+ dtype=float,
461
+ )
462
+ ]
463
+ # Make sure that scaled model dimension is divisible by scaled head dimension.
464
+ query_dims = [
465
+ int(
466
+ make_divisible(
467
+ self.model_dim * m, divisor=self.head_dim * head_multiple_of
468
+ )
469
+ )
470
+ for m in qkv_multipliers
471
+ ]
472
+ else:
473
+ raise NotImplementedError(
474
+ f"QKV multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
475
+ )
476
+
477
+ # compute the number of query, key, and value heads
478
+ # For multi-head and multi-query attention, the number of heads for query, key, and value are the same.
479
+ # For group query attention, the number of key and value heads are the same.
480
+ self.num_query_heads = [
481
+ int(compute_heads(q_dim, self.head_dim)) for q_dim in query_dims
482
+ ]
483
+ self.num_kv_heads = [
484
+ q_heads // self.num_gqa_groups for q_heads in self.num_query_heads
485
+ ]
486
+
487
+ # Feed-forward network (FFN) multipliers
488
+ if isinstance(self.ffn_multipliers, Number):
489
+ # All FFN layers have the same latent dimensions, resulting in uniform allocation of parameters.
490
+ self.ffn_multipliers = [self.ffn_multipliers] * self.num_transformer_layers
491
+ elif isinstance(self.ffn_multipliers, (tuple, list)):
492
+ # Each FFN layer have different latent dimensions assuming ffn_multipliers[0] != ffn_multipliers[1].
493
+ # This results in variable allocation of parameters in FFN layer.
494
+ # This scaling is known as layer-wise or block-wise scaling: https://arxiv.org/abs/2008.00623
495
+ if len(self.ffn_multipliers) == 2:
496
+ self.ffn_multipliers = [
497
+ round(v, 2)
498
+ for v in np.linspace(
499
+ self.ffn_multipliers[0],
500
+ self.ffn_multipliers[1],
501
+ num=self.num_transformer_layers,
502
+ dtype=float,
503
+ )
504
+ ]
505
+ else:
506
+ assert (
507
+ len(self.ffn_multipliers) == self.num_transformer_layers
508
+ ), f"{len(self.ffn_multipliers)=}!={self.num_transformer_layers=}"
509
+ else:
510
+ raise NotImplementedError(
511
+ f"FFN multipliers should be a single number or a list containing exactly two numbers. Got: {qkv_multipliers}."
512
+ )
513
+
514
+ # check num_query_heads divisible by num_kv_heads for every layer
515
+ for layer_idx in range(len(query_dims)):
516
+ assert self.num_query_heads[layer_idx] % self.num_kv_heads[layer_idx] == 0
517
+
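
As a cross-check of __post_init__ against the per-layer lists stored in config.json's text_config, the following sketch reproduces the head counts for OpenELM-270M (model_dim=1280, head_dim=64, num_gqa_groups=4, qkv multipliers interpolated from 0.5 to 1.0 over 16 layers), using make_divisible and compute_heads defined above:

import numpy as np

multipliers = [round(v, 2) for v in np.linspace(0.5, 1.0, num=16, dtype=float)]
query_dims = [int(make_divisible(1280 * m, divisor=64 * 4)) for m in multipliers]
num_query_heads = [compute_heads(d, 64) for d in query_dims]
num_kv_heads = [q // 4 for q in num_query_heads]
print(num_query_heads)  # [12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20]
print(num_kv_heads)     # [3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5]
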
518
+ class OpenELMRMSNorm(nn.Module):
519
+ def __init__(self, num_features: int, eps: float = 1e-6):
520
+ """
521
+ Initialize the OpenELMRMSNorm normalization layer.
522
+ Args:
523
+ dim (int): The dimension of the input tensor.
524
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
525
+ Attributes:
526
+ eps (float): A small value added to the denominator for numerical stability.
527
+ weight (nn.Parameter): Learnable scaling parameter.
528
+ """
529
+ super().__init__()
530
+ self.eps = eps
531
+ self.weight = nn.Parameter(torch.ones(num_features))
532
+ self.num_features = num_features
533
+
534
+ def _norm(self, x: Tensor) -> Tensor:
535
+ """
536
+ Apply the OpenELMRMSNorm normalization to the input tensor.
537
+ Args:
538
+ x (torch.Tensor): The input tensor.
539
+ Returns:
540
+ torch.Tensor: The normalized tensor.
541
+ """
542
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
543
+
544
+ def forward(self, x: Tensor) -> Tensor:
545
+ """
546
+ Forward pass through the OpenELMRMSNorm layer.
547
+ Args:
548
+ x (torch.Tensor): The input tensor.
549
+ Returns:
550
+ torch.Tensor: The output tensor after applying OpenELMRMSNorm.
551
+ """
552
+ output = self._norm(x.float()).type_as(x)
553
+ return output * self.weight
554
+
555
+ def extra_repr(self) -> str:
556
+ return (
557
+ super().extra_repr() + f"num_features={self.num_features}, eps={self.eps}"
558
+ )
559
+
560
+
561
+ class OpenELMPreTrainedModel(PreTrainedModel):
562
+ config_class = OpenELMConfig
563
+ base_model_prefix = "transformer"
564
+ supports_gradient_checkpointing = True
565
+ _no_split_modules = ["OpenELMDecoderLayer"]
566
+ _skip_keys_device_placement = "past_key_values"
567
+
568
+ def __init__(self, *inputs, **kwargs) -> None:
569
+ super().__init__(*inputs, **kwargs)
570
+
571
+ def _init_weights(self, module: nn.Module) -> None:
572
+ """Initialize the weights."""
573
+ if isinstance(module, nn.Linear):
574
+ # Slightly different from the TF version which uses truncated_normal for initialization
575
+ # cf https://github.com/pytorch/pytorch/pull/5617
576
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
577
+ if module.bias is not None:
578
+ module.bias.data.zero_()
579
+ elif isinstance(module, nn.Embedding):
580
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
581
+ if module.padding_idx is not None:
582
+ module.weight.data[module.padding_idx].zero_()
583
+ elif isinstance(module, OpenELMRMSNorm):
584
+ module.weight.data.fill_(1.0)
585
+
586
+
587
+ def _rotate_half(x: Tensor) -> Tensor:
588
+ x1, x2 = x.chunk(2, dim=-1)
589
+ return torch.cat((-x2, x1), dim=-1)
590
+
591
+
592
+ def _apply_rotary_pos_emb(x: Tensor, pos_sin: Tensor, pos_cos: Tensor) -> Tensor:
593
+ return (x * pos_cos) + (_rotate_half(x) * pos_sin)
594
+
595
+
596
+ class OpenELMRotaryEmbedding(torch.nn.Module):
597
+ """
598
+ The rotary position embeddings (aka RoPE) from `RoFormer <https://arxiv.org/abs/2104.09864>`_.
599
+ RoPE encodes the position information of tokens using a rotation matrix, and is able to capture
600
+ explicit relative positional dependencies.
601
+ Args:
602
+ model_dim: The dimensionality of the model's hidden state.
603
+ max_seq_length: Maximum sequence length.
604
+ freq_constant: A constant used for computing frequencies.
605
+ """
606
+
607
+ def __init__(
608
+ self, model_dim: int, max_seq_length: int, freq_constant: int = 10000
609
+ ) -> None:
610
+ inv_freq = 1.0 / (
611
+ freq_constant
612
+ ** (torch.arange(0, model_dim, 2, dtype=torch.float32) / model_dim)
613
+ )
614
+ super().__init__()
615
+
616
+ self.model_dim = model_dim
617
+ self.freq_constant = freq_constant
618
+ self.max_seq_length = max_seq_length
619
+
620
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
621
+ self._cached_cos = None
622
+ self._cached_sin = None
623
+ self._cached_seq_length = max_seq_length
624
+ self._compute_sin_cos_embeddings(max_seq_length)
625
+
626
+ def extra_repr(self) -> str:
627
+ return f"\tmodel_dim={self.model_dim}, max_seq_length={self.max_seq_length}, freq_constant={self.freq_constant}"
628
+
629
+ def _compute_sin_cos_embeddings(
630
+ self,
631
+ key_len: int,
632
+ key_device: torch.device = torch.device("cpu"),
633
+ key_dtype: torch.dtype = torch.float32,
634
+ ) -> None:
635
+ """
636
+ Compute sine and cos embeddings.
637
+ Args:
638
+ key_len: Number of tokens in the key embeddings in the transformer model.
639
+ device: Device where the key embeddings are stored.
640
+ key_dtype: Data type of the key embeddings.
641
+ Returns:
642
+ None
643
+ ...note:
644
+ We recalculate the sine and cosine embeddings if any of the following conditions are met:
645
+ 1. The number of tokens in key embeddings are greater than the cached sequence length.
646
+ 2. Sine and cosine caches are empty.
647
+ 3. The device and data type of sine and cosine embeddings does not match with the key embeddings.
648
+ """
649
+ if (
650
+ key_len > self._cached_seq_length
651
+ or self._cached_cos is None
652
+ or (self._cached_cos is not None and self._cached_cos.device != key_device)
653
+ or (self._cached_cos is not None and self._cached_cos.dtype != key_dtype)
654
+ or self._cached_sin is None
655
+ or (self._cached_sin is not None and self._cached_sin.device != key_device)
656
+ or (self._cached_sin is not None and self._cached_sin.dtype != key_dtype)
657
+ ):
658
+ self._cached_seq_length = max(key_len, self._cached_seq_length)
659
+
660
+ # The shape of 'pos_index' is [number of key tokens]
661
+ pos_index = torch.arange(
662
+ self._cached_seq_length,
663
+ dtype=torch.float32,
664
+ device=self.inv_freq.device,
665
+ )
666
+ # The shape of 'pos_index_theta' is [number of key tokens, model dimension]
667
+ pos_index_theta = torch.einsum("i,j->ij", pos_index, self.inv_freq)
668
+ # The shape of 'emb' is [number of key tokens, model dimension]
669
+ emb = torch.cat((pos_index_theta, pos_index_theta), dim=-1)
670
+
671
+ # the shape of cos and sin embeddings is [number of key tokens, model_dim]
672
+ cos_emb = emb.cos().to(dtype=key_dtype, device=key_device)
673
+ sin_emb = emb.sin().to(dtype=key_dtype, device=key_device)
674
+
675
+ # the shape of cached cos and sin embeddings is [1, 1, number of key tokens, model_dim]
676
+ self._cached_cos = cos_emb[None, None, :, :]
677
+ self._cached_sin = sin_emb[None, None, :, :]
678
+
679
+ def forward(
680
+ self,
681
+ query: torch.Tensor,
682
+ key: torch.Tensor,
683
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
684
+ """
685
+ The forward function of RoPE embeddings.
686
+ Args:
687
+ query: Query embeddings in the transformer model. The shape of query embeddings is
688
+ [Batch, number of query heads, number of query tokens, model dimension].
689
+ key: Key embeddings in the transformer model. The shape of key embeddings is
690
+ [Batch, number of key heads, number of key tokens, model dimension].
691
+ Returns:
692
+ A tuple containing the query and key embeddings with positional information. The shape of the returned query
693
+ and key embeddings is the same as the input query and key embeddings respectively.
694
+ ...note:
695
+ The RoPE embedding computation is done in full-precision. After the computation, input query and key tensors
696
+ are casted to original input datatype.
697
+ """
698
+ dim = key.shape[-1]
699
+ key_len = key.shape[2]
700
+ query_len = query.shape[2]
701
+
702
+ assert dim == self.model_dim
703
+ assert key.device == query.device
704
+ assert key.dtype == query.dtype
705
+
706
+ # In the context of self-attention, the lengths of keys and queries are equal.
707
+ # However, in generation tasks, such as predicting the next token in a sequence, the lengths of keys and queries
708
+ # can differ. For instance, when employing key-value (KV) caching for sequence prediction, the keys
709
+ # represent embeddings of previous tokens and the current token, while the query corresponds
710
+ # to the embedding of the current token only.
711
+ assert (
712
+ key_len >= query_len
713
+ ), "Number of keys has to be greater than or equal to number of queries."
714
+
715
+ query_float = query.float()
716
+ key_float = key.float()
717
+
718
+ self._compute_sin_cos_embeddings(
719
+ key_len, key_device=key_float.device, key_dtype=key_float.dtype
720
+ )
721
+ query_float = _apply_rotary_pos_emb(
722
+ x=query_float,
723
+ pos_sin=self._cached_sin[..., key_len - query_len : key_len, :],
724
+ pos_cos=self._cached_cos[..., key_len - query_len : key_len, :],
725
+ )
726
+ key_float = _apply_rotary_pos_emb(
727
+ x=key_float,
728
+ pos_sin=self._cached_sin[..., :key_len, :],
729
+ pos_cos=self._cached_cos[..., :key_len, :],
730
+ )
731
+
732
+ return query_float.type_as(query), key_float.type_as(key)
733
+
734
+
735
+ class OpenELMMultiHeadCausalAttention(nn.Module):
736
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
737
+ super().__init__()
738
+ self.layer_idx = layer_idx
739
+ head_dim = config.head_dim
740
+ q_heads = config.num_query_heads[layer_idx]
741
+ k_heads = config.num_kv_heads[layer_idx]
742
+ v_heads = config.num_kv_heads[layer_idx]
743
+
744
+ self.qkv_proj = nn.Linear(
745
+ in_features=config.model_dim,
746
+ out_features=(q_heads + k_heads + v_heads) * head_dim,
747
+ bias=False,
748
+ )
749
+
750
+ self.pos_embedding = OpenELMRotaryEmbedding(
751
+ model_dim=config.head_dim,
752
+ max_seq_length=config.rope_max_length,
753
+ freq_constant=config.rope_freq_constant,
754
+ )
755
+
756
+ if config.normalize_qk_projections:
757
+ self.q_norm = OpenELMRMSNorm(
758
+ num_features=config.head_dim,
759
+ )
760
+ self.k_norm = OpenELMRMSNorm(
761
+ num_features=config.head_dim,
762
+ )
763
+ else:
764
+ self.q_norm = None
765
+ self.k_norm = None
766
+
767
+ self.out_proj = nn.Linear(
768
+ in_features=q_heads * head_dim,
769
+ out_features=config.model_dim,
770
+ bias=False,
771
+ )
772
+
773
+ self.head_dim = config.head_dim
774
+ self.num_q_heads = q_heads
775
+ self.num_k_heads = k_heads
776
+ self.num_v_heads = v_heads
777
+ self.transformer_dim = config.model_dim
778
+ self.num_groups = self.num_q_heads // self.num_k_heads
779
+
780
+ def extra_repr(self) -> str:
781
+ return (
782
+ super().extra_repr()
783
+ + f"query_heads={self.num_q_heads}, key_heads={self.num_k_heads}, value_heads={self.num_v_heads}"
784
+ )
785
+
786
+ def forward(
787
+ self,
788
+ hidden_states: torch.Tensor,
789
+ attention_mask: Optional[torch.Tensor] = None,
790
+ past_key_value: Optional[Cache] = None,
791
+ output_attentions: bool = False,
792
+ use_cache: bool = False,
793
+ cache_position: Optional[torch.LongTensor] = None,
794
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
795
+ """
796
+ Forward pass of multi-head self-attention.
797
+ Args:
798
+ hidden_states: Input tensor of the shape [batch size, sequence length, model dimension].
799
+ past_key_value: Tensor storing the cached keys and values.
800
+ output_attentions: output attention weights.
801
+ use_cache: Specifies whether to use kv-cache for generation.
802
+ cache_position: used for updating the kv-cache.
803
+ Returns:
804
+ The output of the same shape as the input, optionally with a tensor containing cached keys and values.
805
+ """
806
+
807
+ # scaled_dot_product_attention does not return attention weights, set output_attentions to False
808
+ output_attentions = False
809
+ batch_size, seq_length, d_model = hidden_states.size()
810
+
811
+ # [B, S, d] --> [B, S, (q_h + k_h + v_h) * h]
812
+ qkv = self.qkv_proj(hidden_states)
813
+ # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h]
814
+ qkv = qkv.reshape(
815
+ batch_size,
816
+ seq_length,
817
+ self.num_q_heads + self.num_k_heads + self.num_v_heads,
818
+ self.head_dim,
819
+ )
820
+ # [B, S, (q_h + k_h + v_h), h] --> [B, (q_h + k_h + v_h), S, h]
821
+ qkv = qkv.transpose(1, 2)
822
+ # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h]
823
+ queries, keys, values = qkv.split(
824
+ [self.num_q_heads, self.num_k_heads, self.num_v_heads], dim=1
825
+ )
826
+
827
+ if self.q_norm is not None:
828
+ queries = self.q_norm(queries)
829
+
830
+ if self.k_norm is not None:
831
+ keys = self.k_norm(keys)
832
+
833
+ past_key_value = getattr(self, "past_key_value", past_key_value)
834
+
835
+ if past_key_value is not None:
836
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
837
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
838
+ cache_kwargs = {"cache_position": cache_position}
839
+ keys, values = past_key_value.update(
840
+ keys, values, self.layer_idx, cache_kwargs
841
+ )
842
+
843
+ # Add positional embedding
844
+ queries, keys = self.pos_embedding(queries, keys)
845
+
846
+ if self.num_groups != 1:
847
+ # GQA
848
+ # [B, k_h, S, h] --> [B, q_h, S, h]
849
+ keys = keys.repeat_interleave(self.num_groups, dim=1)
850
+ # [B, v_h, S, h] --> [B, q_h, S, h]
851
+ values = values.repeat_interleave(self.num_groups, dim=1)
852
+
853
+ causal_mask = attention_mask
854
+ if attention_mask is not None and cache_position is not None:
855
+ causal_mask = causal_mask[:, :, cache_position, : keys.shape[-2]]
856
+
857
+ attn_output = F.scaled_dot_product_attention(
858
+ queries,
859
+ keys,
860
+ values,
861
+ attn_mask=causal_mask,
862
+ dropout_p=0,
863
+ )
864
+
865
+ attn_output = attn_output.transpose(1, 2).contiguous()
866
+ attn_output = attn_output.reshape(
867
+ batch_size, seq_length, self.num_q_heads * self.head_dim
868
+ )
869
+ attn_output = self.out_proj(attn_output)
870
+ if not output_attentions:
871
+ attn_weights = None
872
+ return attn_output, attn_weights, past_key_value
873
+
874
+
875
+ class OpenELMFeedForwardNetwork(nn.Module):
876
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
877
+ super().__init__()
878
+ ffn_multiplier = config.ffn_multipliers[layer_idx]
879
+ intermediate_dim = int(
880
+ make_divisible(
881
+ ffn_multiplier * config.model_dim,
882
+ divisor=config.ffn_dim_divisor,
883
+ )
884
+ )
885
+ if config.ffn_with_glu:
886
+ # FFN with Gated linear unit, as described in https://arxiv.org/abs/2002.05202v1.
887
+ self.proj_1 = nn.Linear(
888
+ in_features=config.model_dim,
889
+ out_features=2 * intermediate_dim,
890
+ bias=False,
891
+ )
892
+ self.proj_2 = nn.Linear(
893
+ in_features=intermediate_dim,
894
+ out_features=config.model_dim,
895
+ bias=False,
896
+ )
897
+ self.ffn_with_glu = True
898
+ else:
899
+ # Standard FFN, as described in https://arxiv.org/abs/1706.03762
900
+ self.proj_1 = nn.Linear(
901
+ in_features=config.model_dim,
902
+ out_features=intermediate_dim,
903
+ bias=False,
904
+ )
905
+ self.proj_2 = nn.Linear(
906
+ in_features=intermediate_dim,
907
+ out_features=config.model_dim,
908
+ bias=False,
909
+ )
910
+ self.ffn_with_glu = False
911
+
912
+ self.act = ACT2FN[config.activation_fn_name]
913
+
914
+ def extra_repr(self) -> str:
915
+ return super().extra_repr() + f"(ffn_with_glu) : {self.ffn_with_glu}"
916
+
917
+ def forward(self, x: Tensor) -> Tensor:
918
+ """Forward function of FFN layer.
919
+ Args:
920
+ x: Input tensor of the shape [batch size, sequence length, model dimension].
921
+ Returns:
922
+ A tensor of the same shape as the input.
923
+ """
924
+ if self.ffn_with_glu:
925
+ y_12 = self.proj_1(x)
926
+ y_1, y_2 = y_12.chunk(2, dim=-1)
927
+ y = self.act(y_1) * y_2
928
+ return self.proj_2(y)
929
+ else:
930
+ return self.proj_2(self.act(self.proj_1(x)))
931
+
932
+
933
+ class OpenELMDecoderLayer(nn.Module):
934
+ def __init__(self, config: OpenELMConfig, layer_idx: int) -> None:
935
+ super().__init__()
936
+ self.attn = OpenELMMultiHeadCausalAttention(config=config, layer_idx=layer_idx)
937
+ self.ffn = OpenELMFeedForwardNetwork(config=config, layer_idx=layer_idx)
938
+ self.ffn_norm = OpenELMRMSNorm(
939
+ num_features=config.model_dim,
940
+ )
941
+ self.attn_norm = OpenELMRMSNorm(
942
+ num_features=config.model_dim,
943
+ )
944
+
945
+ def forward(
946
+ self,
947
+ hidden_states: torch.Tensor,
948
+ attention_mask: Optional[torch.Tensor] = None,
949
+ position_ids: Optional[torch.LongTensor] = None,
950
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
951
+ output_attentions: Optional[bool] = False,
952
+ use_cache: Optional[bool] = False,
953
+ cache_position: Optional[torch.LongTensor] = None,
954
+ **kwargs,
955
+ ) -> Tuple[
956
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
957
+ ]:
958
+ """
959
+ Args:
960
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
961
+ attention_mask (`torch.FloatTensor`, *optional*):
962
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
963
+ query_sequence_length, key_sequence_length)` if default attention is used.
964
+ output_attentions (`bool`, *optional*):
965
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
966
+ returned tensors for more detail.
967
+ use_cache (`bool`, *optional*):
968
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
969
+ (see `past_key_values`).
970
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
971
+ """
972
+ residual = hidden_states
973
+ hidden_states = self.attn_norm(hidden_states)
974
+
975
+ # Self Attention
976
+ hidden_states, self_attn_weights, present_key_value = self.attn(
977
+ hidden_states=hidden_states,
978
+ attention_mask=attention_mask,
979
+ past_key_value=past_key_value,
980
+ output_attentions=output_attentions,
981
+ use_cache=use_cache,
982
+ cache_position=cache_position,
983
+ **kwargs,
984
+ )
985
+ hidden_states = residual + hidden_states
986
+
987
+ # Fully Connected
988
+ residual = hidden_states
989
+ hidden_states = self.ffn_norm(hidden_states)
990
+ hidden_states = self.ffn(hidden_states)
991
+ hidden_states = residual + hidden_states
992
+
993
+ outputs = (hidden_states,)
994
+
995
+ if output_attentions:
996
+ outputs += (self_attn_weights,)
997
+
998
+ if use_cache:
999
+ outputs += (present_key_value,)
1000
+
1001
+ return outputs
1002
+
1003
+
1004
+ class OpenELMModel(OpenELMPreTrainedModel):
1005
+ config_class = OpenELMConfig
1006
+
1007
+ def __init__(self, config: OpenELMConfig):
1008
+ super().__init__(config)
1009
+ self.config = config
1010
+
1011
+ self.token_embeddings = nn.Embedding(
1012
+ embedding_dim=config.model_dim,
1013
+ num_embeddings=config.vocab_size,
1014
+ )
1015
+
1016
+ self.layers = nn.ModuleList(
1017
+ OpenELMDecoderLayer(config=config, layer_idx=layer_idx)
1018
+ for layer_idx in range(config.num_transformer_layers)
1019
+ )
1020
+ self.norm = OpenELMRMSNorm(num_features=config.model_dim)
1021
+ if config.share_input_output_layers:
1022
+ self.classifier = None
1023
+ else:
1024
+ self.classifier = nn.Linear(
1025
+ in_features=config.model_dim,
1026
+ out_features=config.vocab_size,
1027
+ bias=False,
1028
+ )
1029
+ self.num_transformer_layers = config.num_transformer_layers
1030
+ self.gradient_checkpointing = False
1031
+
1032
+ # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
1033
+ # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_context_length`.
1034
+ causal_mask = torch.full(
1035
+ (config.max_context_length, config.max_context_length),
1036
+ fill_value=True,
1037
+ dtype=torch.bool,
1038
+ )
1039
+ self.register_buffer(
1040
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
1041
+ )
1042
+
1043
+ # Initialize weights and apply final processing
1044
+ self.post_init()
1045
+ self.reset_parameters(config=config)
1046
+
1047
+ def get_input_embeddings(self):
1048
+ return self.token_embeddings
1049
+
1050
+ def set_input_embeddings(self, new_embeddings: torch.Tensor):
1051
+ self.token_embeddings = new_embeddings
1052
+
1053
+ def reset_parameters(self, config: OpenELMConfig) -> None:
1054
+ """Initialize the layers in Language Model
1055
+ The initialization scheme is followed, following `OPT <https://arxiv.org/pdf/2205.01068.pdf>`_.
1056
+ Args:
1057
+ use_megatron_std: Use standard deviation as described in Megatron-LM.
1058
+ Returns:
1059
+ None
1060
+ """
1061
+ for module in self.modules():
1062
+ if isinstance(module, nn.Linear):
1063
+ std = module.in_features**-0.5
1064
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
1065
+ if module.bias is not None:
1066
+ torch.nn.init.zeros_(module.bias)
1067
+ elif isinstance(module, nn.Embedding):
1068
+ std = module.embedding_dim**-0.5
1069
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
1070
+ elif isinstance(module, OpenELMRMSNorm):
1071
+ if module.weight is not None:
1072
+ torch.nn.init.ones_(module.weight)
1073
+ if hasattr(module, "bias") and module.bias is not None:
1074
+ torch.nn.init.zeros_(module.bias)
1075
+
1076
+ model_dim = config.model_dim
1077
+ n_layers = config.num_transformer_layers
1078
+ std = (model_dim**-0.5) * ((2 * n_layers) ** -0.5)
1079
+ for param_name, param in self.named_parameters():
1080
+ if param_name.endswith("out_proj.weight") or param_name.endswith(
1081
+ "ffn.proj_2.weight"
1082
+ ):
1083
+ torch.nn.init.normal_(param, mean=0.0, std=std)
1084
+
1085
+ def forward(
1086
+ self,
1087
+ input_ids: torch.LongTensor = None,
1088
+ attention_mask: Optional[torch.Tensor] = None,
1089
+ position_ids: Optional[torch.LongTensor] = None,
1090
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1091
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1092
+ use_cache: Optional[bool] = None,
1093
+ output_attentions: Optional[bool] = None,
1094
+ output_hidden_states: Optional[bool] = None,
1095
+ return_dict: Optional[bool] = None,
1096
+ cache_position: Optional[torch.LongTensor] = None,
1097
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
1098
+ output_attentions = (
1099
+ output_attentions
1100
+ if output_attentions is not None
1101
+ else self.config.output_attentions
1102
+ )
1103
+ output_hidden_states = (
1104
+ output_hidden_states
1105
+ if output_hidden_states is not None
1106
+ else self.config.output_hidden_states
1107
+ )
1108
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1109
+ return_dict = (
1110
+ return_dict if return_dict is not None else self.config.use_return_dict
1111
+ )
1112
+
1113
+ if (input_ids is None) ^ (inputs_embeds is not None):
1114
+ raise ValueError(
1115
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
1116
+ )
1117
+
1118
+ if self.gradient_checkpointing and self.training and use_cache:
1119
+ logger.warning_once(
1120
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
1121
+ )
1122
+ use_cache = False
1123
+
1124
+ if inputs_embeds is None:
1125
+ inputs_embeds = self.token_embeddings(input_ids)
1126
+
1127
+ past_seen_tokens = 0
1128
+ if use_cache: # kept for BC (cache positions)
1129
+ if not isinstance(past_key_values, StaticCache):
1130
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1131
+ past_seen_tokens = past_key_values.get_seq_length()
1132
+
1133
+ if cache_position is None:
1134
+ cache_position = torch.arange(
1135
+ past_seen_tokens,
1136
+ past_seen_tokens + inputs_embeds.shape[1],
1137
+ device=inputs_embeds.device,
1138
+ )
1139
+
1140
+ if position_ids is None:
1141
+ position_ids = cache_position.unsqueeze(0)
1142
+
1143
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
1144
+
1145
+ # embed positions
1146
+ hidden_states = inputs_embeds
1147
+
1148
+ # decoder layers
1149
+ all_hidden_states = () if output_hidden_states else None
1150
+ all_self_attns = () if output_attentions else None
1151
+ next_decoder_cache = None
1152
+
1153
+ for decoder_layer in self.layers:
1154
+ if output_hidden_states:
1155
+ all_hidden_states += (hidden_states,)
1156
+
1157
+ if self.gradient_checkpointing and self.training:
1158
+ layer_outputs = self._gradient_checkpointing_func(
1159
+ decoder_layer.__call__,
1160
+ hidden_states,
1161
+ causal_mask,
1162
+ position_ids,
1163
+ past_key_values,
1164
+ output_attentions,
1165
+ use_cache,
1166
+ cache_position,
1167
+ )
1168
+ else:
1169
+ layer_outputs = decoder_layer(
1170
+ hidden_states,
1171
+ attention_mask=causal_mask,
1172
+ position_ids=position_ids,
1173
+ past_key_value=past_key_values,
1174
+ output_attentions=output_attentions,
1175
+ use_cache=use_cache,
1176
+ cache_position=cache_position,
1177
+ )
1178
+
1179
+ hidden_states = layer_outputs[0]
1180
+
1181
+ if use_cache:
1182
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1183
+
1184
+ if output_attentions:
1185
+ all_self_attns += (layer_outputs[1],)
1186
+
1187
+ hidden_states = self.norm(hidden_states)
1188
+
1189
+ # add hidden states from the last decoder layer
1190
+ if output_hidden_states:
1191
+ all_hidden_states += (hidden_states,)
1192
+
1193
+ next_cache = None
1194
+ if use_cache:
1195
+ next_cache = (
1196
+ next_decoder_cache.to_legacy_cache()
1197
+ if isinstance(next_decoder_cache, Cache)
1198
+ else next_decoder_cache
1199
+ )
1200
+ if not return_dict:
1201
+ return tuple(
1202
+ v
1203
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
1204
+ if v is not None
1205
+ )
1206
+ return BaseModelOutputWithPast(
1207
+ last_hidden_state=hidden_states,
1208
+ past_key_values=next_cache,
1209
+ hidden_states=all_hidden_states,
1210
+ attentions=all_self_attns,
1211
+ )
1212
+
1213
+ def _update_causal_mask(self, attention_mask, input_tensor):
1214
+ if self.config._attn_implementation == "flash_attention_2":
1215
+ if attention_mask is not None and 0.0 in attention_mask:
1216
+ return attention_mask
1217
+ return None
1218
+
1219
+ batch_size, seq_length = input_tensor.shape[:2]
1220
+ dtype = input_tensor.dtype
1221
+ device = input_tensor.device
1222
+
1223
+ # support going beyond the cached `max_context_length`
1224
+ if seq_length > self.causal_mask.shape[-1]:
1225
+ causal_mask = torch.full(
1226
+ (2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]),
1227
+ fill_value=1,
1228
+ )
1229
+ self.register_buffer(
1230
+ "causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False
1231
+ )
1232
+
1233
+ # We use the current dtype to avoid any overflows
1234
+ min_dtype = torch.finfo(dtype).min
1235
+ causal_mask = (
1236
+ self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype)
1237
+ * min_dtype
1238
+ )
1239
+
1240
+ causal_mask = causal_mask.to(dtype=dtype, device=device)
1241
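+ # Fold the 2D padding mask into the 4D causal mask: positions that are causally
+ # visible but padded out are filled with min_dtype so softmax ignores them.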
+ if attention_mask is not None and attention_mask.dim() == 2:
1242
+ mask_length = attention_mask.shape[-1]
1243
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[
1244
+ :, None, None, :
1245
+ ].eq(0.0)
1246
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(
1247
+ padding_mask, min_dtype
1248
+ )
1249
+
1250
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1251
+ # For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1252
+ is_tracing = (
1253
+ torch.jit.is_tracing()
1254
+ or isinstance(input_tensor, torch.fx.Proxy)
1255
+ or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1256
+ )
1257
+ if not is_tracing and torch.any(attention_mask != 1):
1258
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1259
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1260
+ # Details: https://github.com/pytorch/pytorch/issues/110213
1261
+ causal_mask = causal_mask.mul(
1262
+ ~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
1263
+ ).to(dtype)
1264
+
1265
+ return causal_mask
1266
+
1267
+
1268
+ class OpenELMForCausalLM(OpenELMPreTrainedModel):
1269
+ _tied_weights_keys = ["lm_head.weight"]
1270
+
1271
+ def __init__(self, config: OpenELMConfig):
1272
+ super().__init__(config)
1273
+ self.transformer = OpenELMModel(config)
1274
+ self.vocab_size = config.vocab_size
1275
+ if config.share_input_output_layers:
1276
+ self.lm_head = None
1277
+ else:
1278
+ self.lm_head = nn.Linear(config.model_dim, config.vocab_size, bias=False)
1279
+
1280
+ # Initialize weights and apply final processing
1281
+ self.post_init()
1282
+
1283
+ def get_input_embeddings(self):
1284
+ return self.transformer.token_embeddings
1285
+
1286
+ def set_input_embeddings(self, value):
1287
+ self.transformer.token_embeddings = value
1288
+
1289
+ def get_output_embeddings(self):
1290
+ return self.lm_head
1291
+
1292
+ def set_output_embeddings(self, new_embeddings):
1293
+ self.lm_head = new_embeddings
1294
+
1295
+ def set_decoder(self, decoder):
1296
+ self.transformer = decoder
1297
+
1298
+ def get_decoder(self):
1299
+ return self.transformer
1300
+
1301
+ def forward(
1302
+ self,
1303
+ input_ids: torch.LongTensor = None,
1304
+ attention_mask: Optional[torch.Tensor] = None,
1305
+ position_ids: Optional[torch.LongTensor] = None,
1306
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1307
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1308
+ labels: Optional[torch.LongTensor] = None,
1309
+ use_cache: Optional[bool] = None,
1310
+ output_attentions: Optional[bool] = None,
1311
+ output_hidden_states: Optional[bool] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ cache_position: Optional[torch.LongTensor] = None,
1314
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1315
+ output_attentions = (
1316
+ output_attentions
1317
+ if output_attentions is not None
1318
+ else self.config.output_attentions
1319
+ )
1320
+ output_hidden_states = (
1321
+ output_hidden_states
1322
+ if output_hidden_states is not None
1323
+ else self.config.output_hidden_states
1324
+ )
1325
+ return_dict = (
1326
+ return_dict if return_dict is not None else self.config.use_return_dict
1327
+ )
1328
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1329
+ outputs = self.transformer(
1330
+ input_ids=input_ids,
1331
+ attention_mask=attention_mask,
1332
+ position_ids=position_ids,
1333
+ past_key_values=past_key_values,
1334
+ inputs_embeds=inputs_embeds,
1335
+ use_cache=use_cache,
1336
+ output_attentions=output_attentions,
1337
+ output_hidden_states=output_hidden_states,
1338
+ return_dict=return_dict,
1339
+ cache_position=cache_position,
1340
+ )
1341
+
1342
+ hidden_states = outputs[0]
1343
+ if self.lm_head is None:
1344
+ # shared
1345
+ logits = F.linear(
1346
+ hidden_states, weight=self.transformer.token_embeddings.weight
1347
+ )
1348
+ else:
1349
+ logits = self.lm_head(hidden_states)
1350
+ logits = logits[:, : self.config.vocab_size]
1351
+ loss = None
1352
+ if labels is not None:
1353
+ # Shift so that tokens < n predict n
1354
+ shift_logits = logits[..., :-1, :].contiguous()
1355
+ shift_labels = labels[..., 1:].contiguous()
1356
+ # Flatten the tokens
1357
+ loss_fct = CrossEntropyLoss()
1358
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1359
+ shift_labels = shift_labels.view(-1)
1360
+ # Enable model parallelism
1361
+ shift_labels = shift_labels.to(shift_logits.device)
1362
+ loss = loss_fct(shift_logits, shift_labels)
1363
+
1364
+ if not return_dict:
1365
+ output = (logits,) + outputs[1:]
1366
+ return (loss,) + output if loss is not None else output
1367
+
1368
+ return CausalLMOutputWithPast(
1369
+ loss=loss,
1370
+ logits=logits,
1371
+ past_key_values=outputs.past_key_values,
1372
+ hidden_states=outputs.hidden_states,
1373
+ attentions=outputs.attentions,
1374
+ )
1375
+
1376
+ def prepare_inputs_for_generation(
1377
+ self,
1378
+ input_ids,
1379
+ past_key_values=None,
1380
+ attention_mask=None,
1381
+ inputs_embeds=None,
1382
+ **kwargs,
1383
+ ):
1384
+ past_length = 0
1385
+ if past_key_values is not None:
1386
+ if isinstance(past_key_values, Cache):
1387
+ cache_length = past_key_values.get_seq_length()
1388
+ past_length = past_key_values.seen_tokens
1389
+ max_cache_length = past_key_values.get_max_length()
1390
+ else:
1391
+ cache_length = past_length = past_key_values[0][0].shape[2]
1392
+ max_cache_length = None
1393
+
1394
+ # Keep only the unprocessed tokens:
1395
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1396
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1397
+ # input)
1398
+ if (
1399
+ attention_mask is not None
1400
+ and attention_mask.shape[1] > input_ids.shape[1]
1401
+ ):
1402
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1403
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1404
+ # input_ids based on the past_length.
1405
+ elif past_length < input_ids.shape[1]:
1406
+ input_ids = input_ids[:, past_length:]
1407
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1408
+
1409
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1410
+ if (
1411
+ max_cache_length is not None
1412
+ and attention_mask is not None
1413
+ and cache_length + input_ids.shape[1] > max_cache_length
1414
+ ):
1415
+ attention_mask = attention_mask[:, -max_cache_length:]
1416
+
1417
+ position_ids = kwargs.get("position_ids", None)
1418
+ if attention_mask is not None and position_ids is None:
1419
+ # create position_ids on the fly for batch generation
1420
+ position_ids = attention_mask.long().cumsum(-1) - 1
1421
+ position_ids.masked_fill_(attention_mask == 0, 1)
1422
+ if past_key_values:
1423
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1424
+
1425
+ if self.generation_config.cache_implementation == "static":
1426
+ # generation with static cache
1427
+ cache_position = kwargs.get("cache_position", None)
1428
+ if cache_position is None:
1429
+ past_length = 0
1430
+ else:
1431
+ past_length = cache_position[-1] + 1
1432
+ input_ids = input_ids[:, past_length:]
1433
+ position_ids = position_ids[:, past_length:]
1434
+
1435
+ # we should only keep a `cache_position` in generate, and do +=1.
1436
+ # same goes for position ids. Could also help with continued generation.
1437
+ cache_position = torch.arange(
1438
+ past_length,
1439
+ past_length + position_ids.shape[-1],
1440
+ device=position_ids.device,
1441
+ )
1442
+
1443
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1444
+ if inputs_embeds is not None and past_key_values is None:
1445
+ model_inputs = {"inputs_embeds": inputs_embeds}
1446
+ else:
1447
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1448
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1449
+ # We could use `next_tokens` directly instead.
1450
+ model_inputs = {"input_ids": input_ids.contiguous()}
1451
+
1452
+ model_inputs.update(
1453
+ {
1454
+ "position_ids": position_ids.contiguous(),
1455
+ "cache_position": cache_position,
1456
+ "past_key_values": past_key_values,
1457
+ "use_cache": kwargs.get("use_cache"),
1458
+ "attention_mask": attention_mask,
1459
+ }
1460
+ )
1461
+ return model_inputs
1462
+
1463
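+ # Reorder each layer's cached key/value tensors along the batch dimension so the
+ # cache follows the beams selected during beam search.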
+ @staticmethod
1464
+ def _reorder_cache(past_key_values, beam_idx):
1465
+ reordered_past = ()
1466
+ for layer_past in past_key_values:
1467
+ reordered_past += (
1468
+ tuple(
1469
+ past_state.index_select(0, beam_idx.to(past_state.device))
1470
+ for past_state in layer_past
1471
+ ),
1472
+ )
1473
+ return reordered_past
1474
+
1475
+
1476
+ ACT_TYPE = {
1477
+ 'relu': nn.ReLU,
1478
+ 'gelu': nn.GELU
1479
+ }
1480
+
1481
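+ # Projects vision-tower features (vision_hidden_size) into the LLM embedding space
+ # (hidden_size). A connector_type such as "mlp2x_gelu" encodes the MLP depth and the
+ # activation: here, two linear layers with a GELU in between.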
+ class Connector(nn.Module):
1482
+ def __init__(self, config=None):
1483
+ super().__init__()
1484
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', config.connector_type)
1485
+ act_type = config.connector_type.split('_')[-1]
1486
+ mlp_depth = int(mlp_gelu_match.group(1))
1487
+ modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)]
1488
+ for _ in range(1, mlp_depth):
1489
+ modules.append(ACT_TYPE[act_type]())
1490
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
1491
+
1492
+ self._connector = nn.Sequential(*modules)
1493
+
1494
+ def forward(self, x):
1495
+ return self._connector(x)
1496
+
1497
+ class VisionTower(nn.Module):
1498
+ def __init__(self, cfg, model_name_or_path = 'clip'):
1499
+ super().__init__()
1500
+ self._vision_tower = AutoModel.from_pretrained(cfg.model_name_or_path, config = cfg, trust_remote_code=True)
1501
+ self._image_processor = AutoImageProcessor.from_pretrained(cfg.model_name_or_path)
1502
+ self.config = cfg
1503
+
1504
+
1505
+
1506
+ def forward(self, x, **kwargs):
1507
+ image_features = self._vision_tower(x, output_hidden_states=True)
1508
+ image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
1509
+
1510
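+ # "patch" drops the first token of the feature sequence; "cls_patch" keeps the full sequence.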
+ if kwargs.get('vision_feature_select_strategy', 'patch') == 'patch':
1511
+ image_features = image_features[:, 1:]
1512
+ elif kwargs.get('vision_feature_select_strategy', 'patch') == 'cls_patch':
1513
+ image_features = image_features
1514
+ else:
1515
+ raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")
1516
+
1517
+ return image_features
1518
+
1519
+
1520
+
1521
+ @property
1522
+ def vision_tower(self):
1523
+ return self._vision_tower
1524
+
1525
+ @vision_tower.setter
1526
+ def vision_tower(self, vision_tower):
1527
+ self._vision_tower = vision_tower
1528
+
1529
+ def get_value_from_kwargs(kwargs, name):
1530
+ if name in kwargs:
1531
+ return kwargs.pop(name)
1532
+ else:
1533
+ return None
1534
+
1535
+
1536
+
1537
+ class TinyLlavaPreTrainedModel(PreTrainedModel):
1538
+ config_class = TinyLlavaConfig
1539
+ base_model_prefix = "model"
1540
+ supports_gradient_checkpointing = True
1541
+ _no_split_modules = ["LlavaVisionAttention"]
1542
+ _skip_keys_device_placement = "past_key_values"
1543
+ _supports_flash_attn_2 = True
1544
+
1545
+ def _init_weights(self, module):
1546
+ std = (
1547
+ self.config.initializer_range
1548
+ if hasattr(self.config, "initializer_range")
1549
+ else self.config.text_config.initializer_range
1550
+ )
1551
+
1552
+ if hasattr(module, "class_embedding"):
1553
+ module.class_embedding.data.normal_(mean=0.0, std=std)
1554
+
1555
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
1556
+ module.weight.data.normal_(mean=0.0, std=std)
1557
+ if module.bias is not None:
1558
+ module.bias.data.zero_()
1559
+ elif isinstance(module, nn.Embedding):
1560
+ module.weight.data.normal_(mean=0.0, std=std)
1561
+ if module.padding_idx is not None:
1562
+ module.weight.data[module.padding_idx].zero_()
1563
+
1564
+ @property
1565
+ def _supports_sdpa(self):
1566
+ return self.language_model._supports_sdpa
1567
+
1568
+
1569
+ class TinyLlavaForConditionalGeneration(TinyLlavaPreTrainedModel):
1570
+ def __init__(self, config: TinyLlavaConfig):
1571
+
1572
+ super().__init__(config)
1573
+
1574
+ self.language_model = OpenELMForCausalLM(config.text_config)
1575
+ self.vision_tower = VisionTower(config.vision_config, config.vision_model_name_or_path)
1576
+ self.connector = Connector(config)
1577
+ self.post_init()
1578
+
1579
+
1580
+ def get_input_embeddings(self):
1581
+ return self.language_model.get_input_embeddings()
1582
+
1583
+ def set_input_embeddings(self, value):
1584
+ self.language_model.set_input_embeddings(value)
1585
+
1586
+ def get_output_embeddings(self):
1587
+ return self.language_model.get_output_embeddings()
1588
+
1589
+ def set_output_embeddings(self, new_embeddings):
1590
+ self.language_model.set_output_embeddings(new_embeddings)
1591
+
1592
+ def set_decoder(self, decoder):
1593
+ self.language_model.set_decoder(decoder)
1594
+
1595
+ def get_decoder(self):
1596
+ return self.language_model.get_decoder()
1597
+
1598
+ def tie_weights(self):
1599
+ return self.language_model.tie_weights()
1600
+
1601
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
1602
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
1603
+ # update vocab size
1604
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
1605
+ self.config.vocab_size = model_embeds.num_embeddings
1606
+ self.vocab_size = model_embeds.num_embeddings
1607
+ return model_embeds
1608
+
1609
+
1610
+ def forward(
1611
+ self,
1612
+ input_ids: torch.LongTensor = None,
1613
+ attention_mask: Optional[torch.Tensor] = None,
1614
+ position_ids: Optional[torch.LongTensor] = None,
1615
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1616
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1617
+ labels: Optional[torch.LongTensor] = None,
1618
+ use_cache: Optional[bool] = None,
1619
+ output_attentions: Optional[bool] = None,
1620
+ output_hidden_states: Optional[bool] = None,
1621
+ images: Optional[torch.FloatTensor] = None,
1622
+ image_sizes: Optional[List[List[int]]] = None,
1623
+ return_dict: Optional[bool] = None,
1624
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1625
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1626
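+ # When raw input_ids (and optionally images) are provided, first splice the projected
+ # image features into the text embeddings, then delegate to the language model.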
+ if inputs_embeds is None:
1627
+ (
1628
+ input_ids,
1629
+ position_ids,
1630
+ attention_mask,
1631
+ past_key_values,
1632
+ inputs_embeds,
1633
+ labels
1634
+ ) = self.prepare_inputs_labels_for_multimodal(
1635
+ input_ids,
1636
+ position_ids,
1637
+ attention_mask,
1638
+ past_key_values,
1639
+ labels,
1640
+ images,
1641
+ image_sizes
1642
+ )
1643
+ return self.language_model.forward(
1644
+ input_ids=input_ids,
1645
+ attention_mask=attention_mask,
1646
+ position_ids=position_ids,
1647
+ past_key_values=past_key_values,
1648
+ inputs_embeds=inputs_embeds,
1649
+ labels=labels,
1650
+ use_cache=use_cache,
1651
+ output_attentions=output_attentions,
1652
+ output_hidden_states=output_hidden_states,
1653
+ return_dict=return_dict
1654
+ )
1655
+
1656
+ @torch.no_grad()
1657
+ def generate(
1658
+ self,
1659
+ inputs: Optional[torch.Tensor] = None,
1660
+ images: Optional[torch.Tensor] = None,
1661
+ image_sizes: Optional[torch.Tensor] = None,
1662
+ **kwargs,
1663
+ ) -> Union[GenerateOutput, torch.LongTensor]:
1664
+ position_ids = kwargs.pop("position_ids", None)
1665
+ attention_mask = kwargs.pop("attention_mask", None)
1666
+ if "inputs_embeds" in kwargs:
1667
+ raise NotImplementedError("`inputs_embeds` is not supported")
1668
+
1669
+ if images is not None:
1670
+ (
1671
+ inputs,
1672
+ position_ids,
1673
+ attention_mask,
1674
+ _,
1675
+ inputs_embeds,
1676
+ _
1677
+ ) = self.prepare_inputs_labels_for_multimodal(
1678
+ inputs,
1679
+ position_ids,
1680
+ attention_mask,
1681
+ None,
1682
+ None,
1683
+ images,
1684
+ image_sizes=image_sizes
1685
+ )
1686
+ else:
1687
+ inputs_embeds = self.language_model.get_input_embeddings()(inputs)
1688
+
1689
+ return self.language_model.generate(
1690
+ position_ids=position_ids,
1691
+ attention_mask=attention_mask,
1692
+ inputs_embeds=inputs_embeds,
1693
+ **kwargs
1694
+ )
1695
+
1696
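+ # Run images through the vision tower (selecting the configured hidden layer and
+ # feature strategy), then map them into the LLM embedding space with the connector.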
+ def encode_images(self, images):
1697
+ kwargs = {}
1698
+ kwargs['vision_feature_layer'] = self.config.vision_feature_layer
1699
+ kwargs['vision_feature_select_strategy'] = self.config.vision_feature_select_strategy
1700
+ images = images.to(device=self.device, dtype=self.dtype)
1701
+ image_features = self.vision_tower(images, **kwargs)
1702
+ image_features = self.connector(image_features)
1703
+ return image_features
1704
+
1705
+
1706
+
1707
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
1708
+ inputs_embeds=None, **kwargs):
1709
+ images = kwargs.pop("images", None)
1710
+ image_sizes = kwargs.pop("image_sizes", None)
1711
+ inputs = self.language_model.prepare_inputs_for_generation(
1712
+ input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
1713
+ )
1714
+ if images is not None:
1715
+ inputs['images'] = images
1716
+ if image_sizes is not None:
1717
+ inputs['image_sizes'] = image_sizes
1718
+ return inputs
1719
+
1720
+ def prepare_inputs_labels_for_multimodal(
1721
+ self, input_ids, position_ids, attention_mask, past_key_values, labels,
1722
+ images, image_sizes=None
1723
+ ):
1724
+ vision_tower = self.vision_tower
1725
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
1726
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
1727
+
1728
+
1729
+ image_features = self.encode_images(images)
1730
+
1731
+ # TODO: image start / end is not implemented here to support pretraining.
1732
+ if getattr(self.config, 'tune_mm_mlp_adapter', False):
1733
+ raise NotImplementedError
1734
+
1735
+ # Let's just add dummy tensors if they do not exist,
1736
+ # it is a headache to deal with None all the time.
1737
+ # But it is not ideal, and if you have a better idea,
1738
+ # please open an issue / submit a PR, thanks.
1739
+ _labels = labels
1740
+ _position_ids = position_ids
1741
+ _attention_mask = attention_mask
1742
+ if attention_mask is None:
1743
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
1744
+ else:
1745
+ attention_mask = attention_mask.bool()
1746
+ if position_ids is None:
1747
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
1748
+ if labels is None:
1749
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
1750
+
1751
+ # remove the padding using attention_mask -- FIXME
1752
+ _input_ids = input_ids
1753
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
1754
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
1755
+
1756
+ new_input_embeds = []
1757
+ new_labels = []
1758
+ cur_image_idx = 0
1759
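+ # For each sample: split the prompt at IMAGE_TOKEN_INDEX placeholders, embed the text
+ # spans, splice the image features in between, and label image positions with IGNORE_INDEX.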
+ for batch_idx, cur_input_ids in enumerate(input_ids):
1760
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
1761
+ if num_images == 0:
1762
+ cur_image_features = image_features[cur_image_idx]
1763
+ cur_input_embeds_1 = self.language_model.get_input_embeddings()(cur_input_ids)
1764
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
1765
+ new_input_embeds.append(cur_input_embeds)
1766
+ new_labels.append(labels[batch_idx])
1767
+ cur_image_idx += 1
1768
+ continue
1769
+
1770
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
1771
+ cur_input_ids_noim = []
1772
+ cur_labels = labels[batch_idx]
1773
+ cur_labels_noim = []
1774
+ for i in range(len(image_token_indices) - 1):
1775
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
1776
+ cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
1777
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
1778
+ cur_input_embeds = self.language_model.get_input_embeddings()(torch.cat(cur_input_ids_noim))
1779
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
1780
+ cur_new_input_embeds = []
1781
+ cur_new_labels = []
1782
+
1783
+ for i in range(num_images + 1):
1784
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
1785
+ cur_new_labels.append(cur_labels_noim[i])
1786
+ if i < num_images:
1787
+ cur_image_features = image_features[cur_image_idx]
1788
+ cur_image_idx += 1
1789
+ cur_new_input_embeds.append(cur_image_features)
1790
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
1791
+
1792
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
1793
+
1794
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
1795
+ cur_new_labels = torch.cat(cur_new_labels)
1796
+
1797
+ new_input_embeds.append(cur_new_input_embeds)
1798
+ new_labels.append(cur_new_labels)
1799
+
1800
+ # Truncate sequences to max length as image embeddings can make the sequence longer
1801
+ tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
1802
+ if tokenizer_model_max_length is not None:
1803
+ new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
1804
+ new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
1805
+
1806
+ # Combine them
1807
+ max_len = max(x.shape[0] for x in new_input_embeds)
1808
+ batch_size = len(new_input_embeds)
1809
+
1810
+ new_input_embeds_padded = []
1811
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
1812
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
1813
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
1814
+
1815
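+ # Re-pad the spliced sequences to a common length according to tokenizer_padding_side,
+ # rebuilding attention_mask and position_ids to match.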
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
1816
+ cur_len = cur_new_embed.shape[0]
1817
+ if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
1818
+ new_input_embeds_padded.append(torch.cat((
1819
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
1820
+ cur_new_embed
1821
+ ), dim=0))
1822
+ if cur_len > 0:
1823
+ new_labels_padded[i, -cur_len:] = cur_new_labels
1824
+ attention_mask[i, -cur_len:] = True
1825
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
1826
+ else:
1827
+ new_input_embeds_padded.append(torch.cat((
1828
+ cur_new_embed,
1829
+ torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
1830
+ ), dim=0))
1831
+ if cur_len > 0:
1832
+ new_labels_padded[i, :cur_len] = cur_new_labels
1833
+ attention_mask[i, :cur_len] = True
1834
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
1835
+
1836
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
1837
+
1838
+ if _labels is None:
1839
+ new_labels = None
1840
+ else:
1841
+ new_labels = new_labels_padded
1842
+
1843
+ if _attention_mask is None:
1844
+ attention_mask = None
1845
+ else:
1846
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
1847
+
1848
+ if _position_ids is None:
1849
+ position_ids = None
1850
+
1851
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
1852
+
1853
+ def chat(
1854
+ self,
1855
+ prompt: str,
1856
+ tokenizer = None,
1857
+ image: str = None,
1858
+ max_new_tokens: int = 512,
1859
+ num_beams = 1,
1860
+ top_p=None,
1861
+ temperature=0
1862
+ ):
1863
+ image_processor = self.vision_tower._image_processor
1864
+
1865
+ if image is not None:
1866
+ prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
1867
+ conv = conv_phi_v0.copy()
1868
+ conv.append_message(conv.roles[0], prompt)
1869
+ conv.append_message(conv.roles[1], None)
1870
+ prompt = conv.get_prompt()
1871
+ image_tensor = None  # keep generate() below from hitting a NameError when no image is passed
+ if image is not None:
1872
+ image = load_image(image)
1873
+ image_tensor = process_images(image, image_processor, self.config).to(self.device)
1874
+
1875
+ input_ids = (
1876
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
1877
+ .unsqueeze(0).to(self.device)
1878
+ )
1879
+ # Generate
1880
+ stime = time.time()
1881
+
1882
+ with torch.inference_mode():
1883
+ output_ids = self.generate(
1884
+ input_ids,
1885
+ images=image_tensor,
1886
+ do_sample=True if temperature > 0 else False,
1887
+ temperature=temperature,
1888
+ top_p=top_p,
1889
+ num_beams=num_beams,
1890
+ pad_token_id=tokenizer.pad_token_id,
1891
+ max_new_tokens=max_new_tokens,
1892
+ use_cache=True,
1893
+ # stopping_criteria=[stopping_criteria],
1894
+ )
1895
+
1896
+ # print('inference over')
1897
+ generation_time = time.time() - stime
1898
+ outputs = tokenizer.batch_decode(
1899
+ output_ids, skip_special_tokens=True
1900
+ )[0]
1901
+
1902
+ outputs = outputs.strip()
1903
+
1904
+ return outputs, generation_time
1905
+
1906
+
1907
+
1908
+
1909
+
1910
+ AutoConfig.register("tinyllava", TinyLlavaConfig)
1911
+ AutoModelForCausalLM.register(TinyLlavaConfig, TinyLlavaForConditionalGeneration)
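+
+ # Example (sketch): assuming this module and its config are published in a Hub
+ # repository whose config `auto_map` points at them, a checkpoint can be loaded and
+ # queried roughly as follows; "<repo-id>" and "example.jpg" are placeholders, not
+ # values taken from this file.
+ #
+ # from transformers import AutoModelForCausalLM, AutoTokenizer
+ #
+ # model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
+ # tokenizer = AutoTokenizer.from_pretrained("<repo-id>")
+ # answer, seconds = model.chat(
+ #     prompt="Describe the image.",
+ #     image="example.jpg",
+ #     tokenizer=tokenizer,
+ #     max_new_tokens=128,
+ # )
+ # print(answer)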