xj committed on
Commit 00cb8cd · 1 Parent(s): 10c7965

init commit

added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
config.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "_name_or_path": "jxu124/TiO",
3
+ "auto_map": {
4
+ "AutoConfig": "modeling_tio.TiOConfig",
5
+ "AutoModel": "modeling_tio.TiOModel",
6
+ "AutoModelForCausalLM": "modeling_tio.TiOModel",
7
+ "AutoModelForSeq2SeqLM": "modeling_tio.TiOModel",
8
+ "AutoModelForSequenceClassification": "modeling_tio.TiOModel"
9
+ },
10
+ "activation_dropout": 0.0,
11
+ "activation_function": "gelu",
12
+ "add_type_embedding": true,
13
+ "architectures": [
14
+ "TiOModel"
15
+ ],
16
+ "attention_dropout": 0.0,
17
+ "attn_scale_factor": 2.0,
18
+ "bos_token_id": 0,
19
+ "classifier_dropout": 0.0,
20
+ "code_image_size": 128,
21
+ "code_layernorm_embedding": true,
22
+ "d_model": 1280,
23
+ "decoder_attention_heads": 16,
24
+ "decoder_drop_path_rate": 0.0,
25
+ "decoder_ffn_dim": 5120,
26
+ "decoder_layerdrop": 0.0,
27
+ "decoder_layers": 12,
28
+ "decoder_normalize_before": true,
29
+ "decoder_start_token_id": 0,
30
+ "dropout": 0.1,
31
+ "encoder_attention_heads": 16,
32
+ "encoder_drop_path_rate": 0.0,
33
+ "encoder_ffn_dim": 5120,
34
+ "encoder_layerdrop": 0.0,
35
+ "encoder_layers": 24,
36
+ "encoder_normalize_before": true,
37
+ "entangle_position_embedding": false,
38
+ "eos_token_id": 2,
39
+ "forced_eos_token_id": 2,
40
+ "image_bucket_size": 42,
41
+ "init_std": 0.02,
42
+ "is_encoder_decoder": true,
43
+ "layernorm_embedding": true,
44
+ "max_position_embeddings": 1024,
45
+ "model_type": "tio",
46
+ "normformer": true,
47
+ "num_hidden_layers": 24,
48
+ "pad_token_id": 1,
49
+ "patch_layernorm_embedding": true,
50
+ "resnet_drop_path_rate": 0.0,
51
+ "resnet_model_path": null,
52
+ "resnet_type": "resnet152",
53
+ "scale_embedding": false,
54
+ "share_decoder_input_output_embed": true,
55
+ "token_bucket_size": 256,
56
+ "torch_dtype": "float32",
57
+ "transformers_version": "4.28.0",
58
+ "use_cache": false,
59
+ "vocab_size": 59457
60
+ }
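
The `auto_map` block above wires the custom classes added in this commit into the `transformers` Auto API. A minimal loading sketch, assuming the full `jxu124/TiO` repository (weights included) is available on the Hub and that executing its remote code is acceptable:

```python
from transformers import AutoConfig, AutoModel

# trust_remote_code=True tells transformers to import configuration_tio.py and
# modeling_tio.py from the repo instead of using a built-in architecture.
config = AutoConfig.from_pretrained("jxu124/TiO", trust_remote_code=True)
model = AutoModel.from_pretrained("jxu124/TiO", trust_remote_code=True)
print(type(config).__name__, type(model).__name__)  # TiOConfig TiOModel, per auto_map
```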
configuration_tio.py ADDED
@@ -0,0 +1,170 @@
1
+ # coding=utf-8
2
+ # [Apache-2.0] Modified from https://github.com/OFA-Sys/OFA
3
+ """ TiO model configuration"""
4
+ import warnings
5
+ from transformers import PretrainedConfig
6
+
7
+
8
+ class TiOConfig(PretrainedConfig):
9
+ r"""
10
+ This is the configuration class to store the configuration of a [`~TiOModel`]. It is used to instantiate a TiO
11
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
12
+ defaults will yield a similar configuration to that of the TiO
13
+ architecture.
14
+
15
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
16
+ documentation from [`PretrainedConfig`] for more information.
17
+
18
+
19
+ Args:
20
+ vocab_size (`int`, *optional*, defaults to 59457):
21
+ Vocabulary size of the TiO model. Defines the number of different tokens that can be represented by the
22
+ `input_ids` passed when calling [`~TiOModel`].
23
+ d_model (`int`, *optional*, defaults to 512):
24
+ Dimension of the layers and the pooler layer.
25
+ encoder_layers (`int`, *optional*, defaults to 4):
26
+ Number of encoder layers.
27
+ decoder_layers (`int`, *optional*, defaults to 4):
28
+ Number of decoder layers.
29
+ encoder_attention_heads (`int`, *optional*, defaults to 8):
30
+ Number of attention heads for each attention layer in the Transformer encoder.
31
+ decoder_attention_heads (`int`, *optional*, defaults to 8):
32
+ Number of attention heads for each attention layer in the Transformer decoder.
33
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
34
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
35
+ encoder_ffn_dim (`int`, *optional*, defaults to 2048):
36
+ Dimension of the "intermediate" (often named feed-forward) layer in encoder.
37
+ activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
38
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
39
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
40
+ dropout (`float`, *optional*, defaults to 0.1):
41
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
42
+ attention_dropout (`float`, *optional*, defaults to 0.0):
43
+ The dropout ratio for the attention probabilities.
44
+ activation_dropout (`float`, *optional*, defaults to 0.0):
45
+ The dropout ratio for activations inside the fully connected layer.
46
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for classifier.
48
+ max_position_embeddings (`int`, *optional*, defaults to 1024):
49
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
50
+ just in case (e.g., 512 or 1024 or 2048).
51
+ init_std (`float`, *optional*, defaults to 0.02):
52
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
53
+ encoder_layerdrop (`float`, *optional*, defaults to 0.0):
54
+ The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
55
+ for more details.
56
+ decoder_layerdrop (`float`, *optional*, defaults to 0.0):
57
+ The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
58
+ for more details.
59
+ use_cache (`bool`, *optional*, defaults to `True`):
60
+ Whether or not the model should return the last key/values attentions (not used by all models).
61
+ """
62
+
63
+ model_type = "tio"
64
+ keys_to_ignore_at_inference = ["past_key_values"]
65
+
66
+ attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
67
+
68
+ def __init__(
69
+ self,
70
+ vocab_size=59457,
71
+ max_position_embeddings=1024,
72
+ encoder_layers=4,
73
+ encoder_ffn_dim=512 * 4,
74
+ encoder_attention_heads=8,
75
+ decoder_layers=4,
76
+ decoder_ffn_dim=512 * 4,
77
+ decoder_attention_heads=8,
78
+ encoder_layerdrop=0.0,
79
+ decoder_layerdrop=0.0,
80
+ use_cache=True,
81
+ is_encoder_decoder=True,
82
+ activation_function="gelu",
83
+ d_model=512,
84
+ dropout=0.1,
85
+ attention_dropout=0.0,
86
+ activation_dropout=0.0,
87
+ init_std=0.02,
88
+ classifier_dropout=0.0,
89
+ scale_embedding=False,
90
+ pad_token_id=1,
91
+ bos_token_id=0,
92
+ decoder_start_token_id=0,
93
+ eos_token_id=2,
94
+ forced_eos_token_id=2,
95
+ encoder_normalize_before=True,
96
+ decoder_normalize_before=True,
97
+ normformer=True,
98
+ encoder_drop_path_rate=0.0,
99
+ decoder_drop_path_rate=0.0,
100
+ layernorm_embedding=True,
101
+ patch_layernorm_embedding=True,
102
+ resnet_type="resnet101",
103
+ resnet_model_path=None,
104
+ resnet_drop_path_rate=0.0,
105
+ token_bucket_size=256,
106
+ image_bucket_size=42,
107
+ add_type_embedding=True,
108
+ share_decoder_input_output_embed=True,
109
+ attn_scale_factor=2.0,
110
+ code_layernorm_embedding=True,
111
+ code_image_size=128,
112
+ entangle_position_embedding=False,
113
+ **kwargs
114
+ ):
115
+ self.vocab_size = vocab_size
116
+ self.max_position_embeddings = max_position_embeddings
117
+ self.d_model = d_model
118
+ self.encoder_ffn_dim = encoder_ffn_dim
119
+ self.encoder_layers = encoder_layers
120
+ self.encoder_attention_heads = encoder_attention_heads
121
+ self.decoder_ffn_dim = decoder_ffn_dim
122
+ self.decoder_layers = decoder_layers
123
+ self.decoder_attention_heads = decoder_attention_heads
124
+ self.dropout = dropout
125
+ self.attention_dropout = attention_dropout
126
+ self.activation_dropout = activation_dropout
127
+ self.activation_function = activation_function
128
+ self.init_std = init_std
129
+ self.encoder_layerdrop = encoder_layerdrop
130
+ self.decoder_layerdrop = decoder_layerdrop
131
+ self.classifier_dropout = classifier_dropout
132
+ self.use_cache = use_cache
133
+ self.num_hidden_layers = encoder_layers
134
+ self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
135
+ self.encoder_normalize_before = encoder_normalize_before
136
+ self.decoder_normalize_before = decoder_normalize_before
137
+ self.normformer = normformer
138
+ self.encoder_drop_path_rate = encoder_drop_path_rate
139
+ self.decoder_drop_path_rate = decoder_drop_path_rate
140
+ self.layernorm_embedding = layernorm_embedding
141
+ self.patch_layernorm_embedding = patch_layernorm_embedding
142
+ self.resnet_type = resnet_type
143
+ self.resnet_model_path = resnet_model_path
144
+ self.resnet_drop_path_rate = resnet_drop_path_rate
145
+ self.token_bucket_size = token_bucket_size
146
+ self.image_bucket_size = image_bucket_size
147
+ self.add_type_embedding = add_type_embedding
148
+ self.share_decoder_input_output_embed = share_decoder_input_output_embed
149
+ self.attn_scale_factor = attn_scale_factor
150
+ self.code_layernorm_embedding = code_layernorm_embedding
151
+ self.code_image_size = code_image_size
152
+ self.entangle_position_embedding = entangle_position_embedding
153
+
154
+ super().__init__(
155
+ pad_token_id=pad_token_id,
156
+ bos_token_id=bos_token_id,
157
+ eos_token_id=eos_token_id,
158
+ is_encoder_decoder=is_encoder_decoder,
159
+ decoder_start_token_id=bos_token_id,
160
+ forced_eos_token_id=forced_eos_token_id,
161
+ **kwargs,
162
+ )
163
+
164
+ # ensure backward compatibility for BART CNN models
165
+ if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
166
+ self.forced_bos_token_id = self.bos_token_id
167
+ warnings.warn(
168
+ f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
169
+ "The config can simply be saved and uploaded again to be fixed."
170
+ )
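
Note that the constructor defaults above describe a much smaller model (d_model=512, 4 encoder and 4 decoder layers) than the released config.json, which overrides them. A rough sketch of building the released configuration directly from the class (hypothetical local import; in the packaged repo the module is normally loaded through `auto_map`):

```python
from configuration_tio import TiOConfig  # assumes the file is importable locally

# Override the small constructor defaults with the values from config.json above.
config = TiOConfig(
    vocab_size=59457,
    d_model=1280,
    encoder_layers=24,
    decoder_layers=12,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    encoder_ffn_dim=5120,
    decoder_ffn_dim=5120,
    resnet_type="resnet152",
)
print(config.hidden_size)          # 1280 -- attribute_map aliases hidden_size to d_model
print(config.num_attention_heads)  # 16   -- aliased to encoder_attention_heads
```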
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_tio.py ADDED
@@ -0,0 +1,2015 @@
1
+ # coding=utf-8
2
+ # [Apache-2.0] Modified from https://github.com/OFA-Sys/OFA
3
+ """ PyTorch TiO model."""
4
+
5
+ import math
6
+ import random
7
+ from typing import Optional, Tuple
8
+ from dataclasses import dataclass
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
14
+
15
+ from transformers.activations import ACT2FN
16
+ from transformers import PreTrainedModel
17
+ # from ...activations import ACT2FN
18
+ # from ...file_utils import (
19
+ # add_code_sample_docstrings,
20
+ # add_end_docstrings,
21
+ # add_start_docstrings,
22
+ # add_start_docstrings_to_model_forward,
23
+ # replace_return_docstrings,
24
+ # )
25
+ # from ...file_utils import ModelOutput
26
+ # from ...modeling_outputs import (
27
+ # BaseModelOutputWithPastAndCrossAttentions,
28
+ # Seq2SeqLMOutput,
29
+ # Seq2SeqModelOutput,
30
+ # )
31
+ # from ...modeling_utils import PreTrainedModel
32
+ # from ...utils import logging
33
+ from transformers.utils import logging, ModelOutput
34
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
35
+ from .configuration_tio import TiOConfig
36
+ from .resnet import ResNet
37
+ from torch import Tensor
38
+ from typing import Dict, List, Optional, Tuple
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "TiOConfig"
43
+ _TOKENIZER_FOR_DOC = "TiOTokenizer"
44
+
45
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
46
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
47
+
48
+ DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
49
+
50
+ try:
51
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
52
+
53
+ has_fused_layernorm = True
54
+
55
+ class FusedLayerNorm(_FusedLayerNorm):
56
+ @torch.jit.unused
57
+ def forward(self, x):
58
+ if not x.is_cuda:
59
+ return super().forward(x)
60
+ else:
61
+ with torch.cuda.device(x.device):
62
+ return super().forward(x)
63
+
64
+ except ImportError:
65
+ has_fused_layernorm = False
66
+
67
+
68
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
69
+ r"""
70
+ Layer normalization.
71
+ If apex is available, use `FusedLayerNorm` instead.
72
+ """
73
+ if torch.jit.is_scripting():
74
+ export = True
75
+ if not export and torch.cuda.is_available() and has_fused_layernorm:
76
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
77
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
78
+
79
+
80
+ def make_token_bucket_position(bucket_size, max_position=DEFAULT_MAX_SOURCE_POSITIONS):
81
+ r"""
82
+ Make relative position indices for the text.
83
+ """
84
+ context_pos = torch.arange(max_position, dtype=torch.long)[:, None]
85
+ memory_pos = torch.arange(max_position, dtype=torch.long)[None, :]
86
+ relative_pos = context_pos - memory_pos
87
+ sign = torch.sign(relative_pos)
88
+ mid = bucket_size // 2
89
+ abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos))
90
+ log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
91
+ log_pos = log_pos.int()
92
+ bucket_pos = torch.where(abs_pos.le(mid), relative_pos, log_pos * sign).long()
93
+ return bucket_pos + bucket_size - 1
94
+
95
+
96
+ def make_image_bucket_position(bucket_size, num_relative_distance):
97
+ r"""
98
+ Make relative position indices for the image.
99
+ """
100
+ coords_h = torch.arange(bucket_size)
101
+ coords_w = torch.arange(bucket_size)
102
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
103
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
104
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
105
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
106
+ relative_coords[:, :, 0] += bucket_size - 1 # shift to start from 0
107
+ relative_coords[:, :, 1] += bucket_size - 1
108
+ relative_coords[:, :, 0] *= 2 * bucket_size - 1
109
+ relative_position_index = torch.zeros(size=(bucket_size * bucket_size + 1,) * 2, dtype=relative_coords.dtype)
110
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
111
+ relative_position_index[0, 0:] = num_relative_distance - 3
112
+ relative_position_index[0:, 0] = num_relative_distance - 2
113
+ relative_position_index[0, 0] = num_relative_distance - 1
114
+ return relative_position_index
115
+
116
+
117
+ def new_arange(x, *size):
118
+ r"""
119
+ Return a Tensor of `size` filled with a range function on the device of x.
120
+ If size is empty, using the size of the variable x.
121
+ """
122
+ if len(size) == 0:
123
+ size = x.size()
124
+ return torch.arange(size[-1], device=x.device).expand(*size).contiguous()
125
+
126
+
127
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
128
+ r"""
129
+ Shift input ids one token to the right.
130
+ """
131
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
132
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
133
+ shifted_input_ids[:, 0] = decoder_start_token_id
134
+
135
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
136
+ # replace possible -100 values in labels by `pad_token_id`
137
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
138
+
139
+ return shifted_input_ids
140
+
141
+
142
+ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
143
+ r"""
144
+ Make causal mask used for uni-directional self-attention.
145
+ """
146
+ bsz, tgt_len = input_ids_shape
147
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
148
+ mask_cond = torch.arange(mask.size(-1))
149
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
150
+ mask = mask.to(dtype)
151
+
152
+ if past_key_values_length > 0:
153
+ mask = torch.cat([torch.ones(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
154
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
155
+
156
+
157
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
158
+ r"""
159
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
160
+ """
161
+ bsz, src_len = mask.size()
162
+ tgt_len = tgt_len if tgt_len is not None else src_len
163
+
164
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
165
+ inverted_mask = 1.0 - expanded_mask
166
+
167
+ return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
168
+
169
+
170
+ def Embedding(num_embeddings, embedding_dim, padding_idx=None, zero_init=False):
171
+ r"""
172
+ Embedding for tokens
173
+ """
174
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
175
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
176
+ if padding_idx is not None:
177
+ nn.init.constant_(m.weight[padding_idx], 0)
178
+ if zero_init:
179
+ nn.init.constant_(m.weight, 0)
180
+ return m
181
+
182
+
183
+ def Linear(in_features, out_features, bias=True):
184
+ r"""
185
+ Implementation of a linear projection with Xavier initialization.
186
+ """
187
+ m = nn.Linear(in_features, out_features, bias)
188
+ nn.init.xavier_uniform_(m.weight)
189
+ if bias:
190
+ nn.init.constant_(m.bias, 0.0)
191
+ return m
192
+
193
+
194
+ class LayerDropModuleList(nn.ModuleList):
195
+ r"""
196
+ A LayerDrop implementation based on :class:`torch.nn.ModuleList`.
197
+
198
+ Args:
199
+ p (float): probability of dropping out each layer
200
+ modules (iterable, optional): an iterable of modules to add
201
+ """
202
+
203
+ def __init__(self, p, modules=None):
204
+ super().__init__(modules)
205
+ self.p = p
206
+
207
+ def __iter__(self):
208
+ dropout_probs = torch.empty(len(self)).uniform_()
209
+ for i, m in enumerate(super().__iter__()):
210
+ if not self.training or (dropout_probs[i] > self.p):
211
+ yield m
212
+
213
+
214
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
215
+ r"""
216
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
217
+
218
+ Args:
219
+ x (`nn.Modules`): input nn layers.
220
+ drop_prob (`float`): drop path ratio.
221
+ training (`bool`): whether it is training or inference.
222
+ """
223
+ if drop_prob == 0.0 or not training:
224
+ return x
225
+ keep_prob = 1 - drop_prob
226
+ shape = (1, x.shape[1], 1)
227
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
228
+ random_tensor.floor_() # binarize
229
+ output = x.div(keep_prob) * random_tensor
230
+ return output
231
+
232
+
233
+ class DropPath(nn.Module):
234
+ r"""
235
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
236
+
237
+ Args:
238
+ drop_prob: drop path ratio.
239
+ """
240
+
241
+ def __init__(self, drop_prob=None):
242
+ super().__init__()
243
+ self.drop_prob = drop_prob
244
+
245
+ def forward(self, x):
246
+ return drop_path(x, self.drop_prob, self.training)
247
+
248
+ def extra_repr(self) -> str:
249
+ return "p={}".format(self.drop_prob)
250
+
251
+
252
+ class TiOAttention(nn.Module):
253
+ r"""
254
+ Multi-headed attention, with additional implementation for NormFormer.
255
+
256
+ Args:
257
+ embed_dim (`int`): embedding dimension.
258
+ num_heads (`int`): the number of attention heads.
259
+ dropout (`float32`): the ratio for dropout.
260
+ is_decoder (`bool`): whether or not decoder attention.
261
+ bias (`bool`): whether to add bias.
262
+ scale_heads (`bool`): whether to learn scaling heads, only for Normformer.
263
+ """
264
+
265
+ def __init__(
266
+ self,
267
+ embed_dim: int,
268
+ num_heads: int,
269
+ dropout: float = 0.0,
270
+ is_decoder: bool = False,
271
+ bias: bool = True,
272
+ scale_heads: bool = True,
273
+ ):
274
+ super().__init__()
275
+ self.embed_dim = embed_dim
276
+ self.num_heads = num_heads
277
+ self.dropout = dropout
278
+ self.head_dim = embed_dim // num_heads
279
+ assert (
280
+ self.head_dim * num_heads == self.embed_dim
281
+ ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
282
+ scale_factor=2
283
+ self.scaling = float(self.head_dim * scale_factor) ** -0.5
284
+ self.is_decoder = is_decoder
285
+
286
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
287
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
288
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
289
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
290
+ self.attn_dropout = nn.Dropout(p=dropout)
291
+ self.c_attn = nn.Parameter(torch.ones((self.num_heads,)), requires_grad=True) if scale_heads else None
292
+
293
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
294
+ r"""
295
+ Reshape tensors for multi-head attention.
296
+ """
297
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
298
+
299
+ def forward(
300
+ self,
301
+ hidden_states: torch.Tensor,
302
+ key_value_states: Optional[torch.Tensor] = None,
303
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
304
+ attention_mask: Optional[torch.Tensor] = None,
305
+ output_attentions: bool = False,
306
+ attn_bias: Optional[torch.Tensor] = None,
307
+ ):
308
+ r"""
309
+ Args:
310
+ hidden_states (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`)`: input states.
311
+ key_value_states (`torch.FloatTensor` of shape (bsz, tgt_len, embed_dim), *optional*): key value states.
312
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
313
+ cached past key value states for fast inference.
314
+ attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, seq_len)`, *optional*): attention mask.
315
+ output_attentions (`bool`, *optional*): whether to output attention weights of all layers.
316
+ attn_bias (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`, *optional*):
317
+ the attention bias for positional information.
318
+
319
+ Returns:
320
+ attn_output (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`): attention outputs.
321
+ attn_weights_reshaped (`torch.FloatTensor`, *optional*): attention weights of all layers.
322
+ past_key_value (`torch.FloatTensor`, *optional*): cached key value states for fast inference.
323
+ """
324
+
325
+ # if key_value_states are provided this layer is used as a cross-attention layer
326
+ # for the decoder
327
+ is_cross_attention = key_value_states is not None
328
+ bsz, tgt_len, embed_dim = hidden_states.size()
329
+
330
+ # get query proj
331
+ query_states = self.q_proj(hidden_states) * self.scaling
332
+ # get key, value proj
333
+ if is_cross_attention and past_key_value is not None:
334
+ # reuse k,v, cross_attentions
335
+ key_states = past_key_value[0]
336
+ value_states = past_key_value[1]
337
+ elif is_cross_attention:
338
+ # cross_attentions
339
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
340
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
341
+ elif past_key_value is not None:
342
+ # reuse k, v, self_attention
343
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
344
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
345
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
346
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
347
+ else:
348
+ # self_attention
349
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
350
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
351
+
352
+ if self.is_decoder:
353
+ past_key_value = (key_states, value_states)
354
+
355
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
356
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
357
+ key_states = key_states.view(*proj_shape)
358
+ value_states = value_states.view(*proj_shape)
359
+
360
+ src_len = key_states.size(1)
361
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
362
+
363
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
364
+ raise ValueError(
365
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
366
+ )
367
+
368
+ # Add attention bias for positional information
369
+ if attn_bias is not None:
370
+ attn_weights += attn_bias
371
+
372
+ if attention_mask is not None:
373
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
374
+ raise ValueError(
375
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
376
+ )
377
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
378
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
379
+
380
+ attn_weights = F.softmax(attn_weights, dim=-1)
381
+
382
+ if output_attentions:
383
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
384
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
385
+ else:
386
+ attn_weights_reshaped = None
387
+
388
+ attn_probs = self.attn_dropout(attn_weights)
389
+
390
+ attn_output = torch.bmm(attn_probs, value_states)
391
+
392
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
393
+ raise ValueError(
394
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
395
+ )
396
+
397
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
398
+ attn_output = attn_output.transpose(1, 2)
399
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
400
+
401
+ if self.c_attn is not None:
402
+ attn_output = attn_output.view(bsz, tgt_len, self.num_heads, self.head_dim)
403
+ attn_output = torch.einsum("bthd,h->bthd", attn_output, self.c_attn)
404
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
405
+
406
+ attn_output = self.out_proj(attn_output)
407
+
408
+ return attn_output, attn_weights_reshaped, past_key_value
409
+
410
+
411
+ class TiOEncoderLayer(nn.Module):
412
+ r"""
413
+ TiO encoder layer implementation.
414
+
415
+ Args:
416
+ config: configuration for TiO.
417
+ drop_path_rate: the ratio for drop path.
418
+ """
419
+
420
+ def __init__(self, config: TiOConfig, drop_path_rate=0.0):
421
+ super().__init__()
422
+ self.embed_dim = config.d_model
423
+ self.self_attn = TiOAttention(
424
+ embed_dim=self.embed_dim,
425
+ num_heads=config.encoder_attention_heads,
426
+ dropout=config.attention_dropout,
427
+ )
428
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
429
+ self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
430
+ self.dropout = nn.Dropout(config.dropout)
431
+ self.activation_fn = ACT2FN[config.activation_function]
432
+ self.activation_dropout = nn.Dropout(config.activation_dropout)
433
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
434
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
435
+ self.ffn_layer_norm = LayerNorm(config.encoder_ffn_dim) if config.normformer else None
436
+ self.final_layer_norm = LayerNorm(self.embed_dim)
437
+ self.normalize_before = config.encoder_normalize_before
438
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
439
+
440
+ def residual_connection(self, x, residual):
441
+ r"""
442
+ Residual connection with drop path.
443
+ """
444
+ return residual + self.drop_path(x)
445
+
446
+ def forward(
447
+ self,
448
+ hidden_states: torch.Tensor,
449
+ attention_mask: torch.Tensor,
450
+ output_attentions: bool = False,
451
+ attn_bias: Optional[torch.Tensor] = None,
452
+ ):
453
+ r"""
454
+ Args:
455
+ hidden_states (`torch.FloatTensor`): input to the layer of shape *(bsz, src_len, embed_dim)*
456
+ attention_mask (`torch.FloatTensor`): attention mask of size
457
+ *(bsz, 1, src_len, src_len)* where padding elements are indicated by very large negative values.
458
+ output_attentions (`bool`, *optional*):
459
+ whether to return the attentions tensors of all attention layers. See `attentions` under
460
+ returned tensors for more detail.
461
+ attn_bias (`torch.FloatTensor`): bias for positional information.
462
+
463
+ Returns:
464
+ outputs (`tuple(torch.FloatTensor)`):
465
+ output hidden states of size (bsz, src_len, embed_dim), optionally with attention weights.
466
+ """
467
+
468
+ residual = hidden_states
469
+ if self.normalize_before:
470
+ hidden_states = self.self_attn_layer_norm(hidden_states)
471
+ hidden_states, attn_weights, _ = self.self_attn(
472
+ hidden_states=hidden_states,
473
+ attention_mask=attention_mask,
474
+ output_attentions=output_attentions,
475
+ attn_bias=attn_bias,
476
+ )
477
+ if self.self_attn_mid_layer_norm:
478
+ hidden_states = self.self_attn_mid_layer_norm(hidden_states)
479
+ hidden_states = self.dropout(hidden_states)
480
+ hidden_states = self.residual_connection(hidden_states, residual)
481
+ if not self.normalize_before:
482
+ hidden_states = self.self_attn_layer_norm(hidden_states)
483
+
484
+ residual = hidden_states
485
+
486
+ if self.normalize_before:
487
+ hidden_states = self.final_layer_norm(hidden_states)
488
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
489
+ hidden_states = self.activation_dropout(hidden_states)
490
+ if self.ffn_layer_norm:
491
+ hidden_states = self.ffn_layer_norm(hidden_states)
492
+ hidden_states = self.fc2(hidden_states)
493
+ hidden_states = self.dropout(hidden_states)
494
+ hidden_states = self.residual_connection(hidden_states, residual)
495
+ if not self.normalize_before:
496
+ hidden_states = self.final_layer_norm(hidden_states)
497
+
498
+ if hidden_states.dtype == torch.float16 and (
499
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
500
+ ):
501
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
502
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
503
+
504
+ outputs = (hidden_states,)
505
+
506
+ if output_attentions:
507
+ outputs += (attn_weights,)
508
+
509
+ return outputs
510
+
511
+
512
+ class TiODecoderLayer(nn.Module):
513
+ r"""
514
+ TiO decoder layer implementation.
515
+
516
+ Args:
517
+ config: configuration for TiO.
518
+ drop_path_rate: the ratio for drop path.
519
+ """
520
+
521
+ def __init__(self, config: TiOConfig, drop_path_rate=0.0):
522
+ super().__init__()
523
+ self.embed_dim = config.d_model
524
+
525
+ self.self_attn = TiOAttention(
526
+ embed_dim=self.embed_dim,
527
+ num_heads=config.decoder_attention_heads,
528
+ dropout=config.attention_dropout,
529
+ is_decoder=True,
530
+ )
531
+ self.dropout = nn.Dropout(p=config.dropout)
532
+ self.activation_fn = ACT2FN[config.activation_function]
533
+ self.activation_dropout = nn.Dropout(p=config.activation_dropout)
534
+
535
+ self.self_attn_layer_norm = LayerNorm(self.embed_dim)
536
+ self.self_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
537
+ self.cross_attn = TiOAttention(
538
+ self.embed_dim,
539
+ config.decoder_attention_heads,
540
+ dropout=config.attention_dropout,
541
+ is_decoder=True,
542
+ )
543
+ self.cross_attn_layer_norm = LayerNorm(self.embed_dim)
544
+ self.cross_attn_mid_layer_norm = LayerNorm(self.embed_dim) if config.normformer else None
545
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
546
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
547
+ self.ffn_layer_norm = LayerNorm(config.decoder_ffn_dim) if config.normformer else None
548
+ self.final_layer_norm = LayerNorm(self.embed_dim)
549
+ self.normalize_before = config.decoder_normalize_before
550
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
551
+
552
+ def residual_connection(self, x, residual):
553
+ r"""
554
+ Residual connection with drop path.
555
+ """
556
+ return residual + self.drop_path(x)
557
+
558
+ def forward(
559
+ self,
560
+ hidden_states: torch.Tensor,
561
+ attention_mask: Optional[torch.Tensor] = None,
562
+ encoder_hidden_states: Optional[torch.Tensor] = None,
563
+ encoder_attention_mask: Optional[torch.Tensor] = None,
564
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
565
+ output_attentions: Optional[bool] = False,
566
+ use_cache: Optional[bool] = False,
567
+ self_attn_bias: Optional[torch.Tensor] = None,
568
+ cross_attn_bias: Optional[torch.Tensor] = None,
569
+ ):
570
+ r"""
571
+ Args:
572
+ hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): input to the layer.
573
+ attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`):
574
+ attention mask where padding elements are indicated by very large negative values.
575
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch, seq_len, embed_dim)`):
576
+ cross attention input to the layer.
577
+ encoder_attention_mask (`torch.FloatTensor` of shape `(bsz, 1, tgt_len, src_len)`):
578
+ encoder attention mask where padding elements are indicated by very large negative values.
579
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
580
+ output_attentions (`bool`, *optional*): whether to return the attentions tensors of all attention layers.
581
+ use_cache (`bool`, *optional*): whether to use cache
582
+ self_attn_bias (`torch.FloatTensor`): self attention bias for positional information.
583
+ cross_attn_bias (`torch.FloatTensor`): cross attention bias for positional information.
584
+ """
585
+
586
+ # Self attention with intermediate layernorm
587
+ residual = hidden_states
588
+ if self.normalize_before:
589
+ hidden_states = self.self_attn_layer_norm(hidden_states)
590
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
591
+ # add present self-attn cache to position 1,2 of present_key_value tuple
592
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
593
+ hidden_states=hidden_states,
594
+ past_key_value=self_attn_past_key_value,
595
+ attention_mask=attention_mask,
596
+ output_attentions=output_attentions,
597
+ attn_bias=self_attn_bias,
598
+ )
599
+ if self.self_attn_mid_layer_norm:
600
+ hidden_states = self.self_attn_mid_layer_norm(hidden_states)
601
+ hidden_states = self.dropout(hidden_states)
602
+ hidden_states = self.residual_connection(hidden_states, residual)
603
+ if not self.normalize_before:
604
+ hidden_states = self.self_attn_layer_norm(hidden_states)
605
+
606
+ # Cross attention with intermediate layernorm
607
+ cross_attn_present_key_value = None
608
+ cross_attn_weights = None
609
+ if encoder_hidden_states is not None:
610
+ residual = hidden_states
611
+ if self.normalize_before:
612
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
613
+ # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
614
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
615
+ hidden_states, cross_attn_weights, cross_attn_present_key_value = self.cross_attn(
616
+ hidden_states=hidden_states,
617
+ key_value_states=encoder_hidden_states,
618
+ attention_mask=encoder_attention_mask,
619
+ past_key_value=cross_attn_past_key_value,
620
+ output_attentions=output_attentions,
621
+ attn_bias=cross_attn_bias,
622
+ )
623
+ if self.cross_attn_mid_layer_norm:
624
+ hidden_states = self.cross_attn_mid_layer_norm(hidden_states)
625
+ hidden_states = self.dropout(hidden_states)
626
+ hidden_states = self.residual_connection(hidden_states, residual)
627
+ if not self.normalize_before:
628
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
629
+
630
+ # add cross-attn to positions 3,4 of present_key_value tuple
631
+ present_key_value = present_key_value + cross_attn_present_key_value
632
+
633
+ # FFN with intermediate layernorm
634
+ residual = hidden_states
635
+ if self.normalize_before:
636
+ hidden_states = self.final_layer_norm(hidden_states)
637
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
638
+ hidden_states = self.activation_dropout(hidden_states)
639
+ if self.ffn_layer_norm:
640
+ hidden_states = self.ffn_layer_norm(hidden_states)
641
+ hidden_states = self.fc2(hidden_states)
642
+ hidden_states = self.dropout(hidden_states)
643
+ hidden_states = self.residual_connection(hidden_states, residual)
644
+ if not self.normalize_before:
645
+ hidden_states = self.final_layer_norm(hidden_states)
646
+
647
+ outputs = (hidden_states,)
648
+
649
+ if output_attentions:
650
+ outputs += (self_attn_weights, cross_attn_weights)
651
+
652
+ if use_cache:
653
+ outputs += (present_key_value,)
654
+
655
+ return outputs
656
+
657
+
658
+ class TiOPreTrainedModel(PreTrainedModel):
659
+ r"""
660
+ Base class for TiO.
661
+ """
662
+
663
+ config_class = TiOConfig
664
+ base_model_prefix = "model"
665
+ supports_gradient_checkpointing = True
666
+
667
+ def _init_weights(self, module):
668
+ r"""
669
+ Weight initialization which follows BERT.
670
+ """
671
+ std = self.config.init_std
672
+ if isinstance(module, nn.Linear):
673
+ module.weight.data.normal_(mean=0.0, std=std)
674
+ if module.bias is not None:
675
+ module.bias.data.zero_()
676
+ elif isinstance(module, nn.Embedding):
677
+ module.weight.data.normal_(mean=0.0, std=std)
678
+ if module.padding_idx is not None:
679
+ module.weight.data[module.padding_idx].zero_()
680
+
681
+ def _set_gradient_checkpointing(self, module, value=False):
682
+ r"""
683
+ Turn on the switch of gradient checkpointing.
684
+ """
685
+ if isinstance(module, (TiODecoder, TiOEncoder)):
686
+ module.gradient_checkpointing = value
687
+
688
+
689
+ @dataclass
690
+ class TiOEncoderOutput(ModelOutput):
691
+ r"""
692
+ Base class for TiO's outputs.
693
+
694
+ Args:
695
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`):
696
+ Sequence of hidden-states at the output of the last layer of the model.
697
+
698
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed
699
+ or when `config.output_hidden_states=True`):
700
+
701
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
702
+ shape `(bsz, seq_len, hidden)`.
703
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
704
+
705
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
706
+ or when `config.output_attentions=True`):
707
+
708
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(bsz, num_heads, seq_len, seq_len)`.
709
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
710
+ heads.
711
+
712
+ position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`):
713
+ positional embeddings of the inputs.
714
+ """
715
+
716
+ last_hidden_state: torch.FloatTensor = None
717
+ padding_mask: torch.Tensor = None
718
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
719
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
720
+ position_embedding: Optional[torch.FloatTensor] = None
721
+
722
+
723
+ TiO_START_DOCSTRING = r"""
724
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
725
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
726
+ etc.)
727
+
728
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
729
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
730
+ and behavior.
731
+
732
+ Parameters:
733
+ config ([`~TiOConfig`]):
734
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
735
+ load the weights associated with the model, only the configuration. Check out the
736
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
737
+ """
738
+
739
+
740
+ TiO_GENERATION_EXAMPLE = r"""
741
+ Image captioning example:
742
+
743
+ ```python
744
+ >>> from PIL import Image
745
+ >>> from torchvision import transforms
746
+ >>> from transformers import TiOTokenizer, TiOForConditionalGeneration
747
+
748
+ >>> mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
749
+ >>> resolution = 256
750
+ >>> patch_resize_transform = transforms.Compose([
751
+ lambda image: image.convert("RGB"),
752
+ transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC),
753
+ transforms.ToTensor(),
754
+ transforms.Normalize(mean=mean, std=std)
755
+ ])
756
+
757
+ >>> model = TiOForConditionalGeneration.from_pretrained(ckpt_dir)
758
+ >>> tokenizer = TiOTokenizer.from_pretrained(ckpt_dir)
759
+
760
+ >>> txt = " what is the description of the image?"
761
+ >>> inputs = tokenizer([txt], max_length=1024, return_tensors="pt")["input_ids"]
762
+ >>> img = Image.open(path_to_image)
763
+ >>> patch_img = patch_resize_transform(img).unsqueeze(0)
764
+
765
+ >>> gen = model.generate(inputs, patch_img=patch_img, num_beams=4)
766
+ >>> print(tokenizer.batch_decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False))
767
+ ```
768
+ """
769
+
770
+
771
+ TiO_INPUTS_DOCSTRING = r"""
772
+ Args:
773
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
774
+ indices of input sequence tokens in the vocabulary, and padding will be ignored by default;
775
+
776
+ indices can be obtained using [`~TiOTokenizer`].
777
+
778
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
779
+ the resized images, which are transformed by the default operations.
780
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
781
+ the second (if it exists) image.
782
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
783
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
784
+ sample_patch_num (`int`): the number of patches to sample.
785
+ decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
786
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
787
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding.
788
+ encoder_outputs (`TiOEncoderOutput`):
789
+ encoder outputs with hidden states, positional embeddings, and padding masks.
790
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
791
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
792
+ shape `(bsz, num_heads, tgt_len, head_size)` and 2 additional tensors of
793
+ shape `(bsz, num_heads, src_len, head_size)`.
794
+ use_cache (`bool`): whether to use cache for faster inference.
795
+ output_attentions (`bool`): whether to output attention weights.
796
+ output_hidden_states (`bool`): whether to output hidden states.
797
+ return_dict (`bool`): unused. Keep it for generation only.
798
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
799
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
800
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
801
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
802
+ """
803
+
804
+
805
+ class TiOEncoder(TiOPreTrainedModel):
806
+ r"""
807
+ TiO encoder consisting of layers of [`TiOEncoderLayer`].
808
+
809
+ Args:
810
+ config: TiOConfig
811
+ embed_tokens (`nn.Embedding`, *optional*): output embedding
812
+ """
813
+
814
+ def __init__(self, config: TiOConfig, embed_tokens: Optional[nn.Embedding] = None):
815
+ super().__init__(config)
816
+
817
+ self.dropout = nn.Dropout(config.dropout)
818
+ self.encoder_layerdrop = config.encoder_layerdrop
819
+
820
+ embed_dim = config.d_model
821
+ self.padding_idx = config.pad_token_id
822
+ self.max_source_positions = config.max_position_embeddings
823
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
824
+ self.num_attention_heads = config.encoder_attention_heads
825
+
826
+ if getattr(config, "layernorm_embedding", False):
827
+ self.layernorm_embedding = LayerNorm(embed_dim)
828
+ else:
829
+ self.layernorm_embedding = None
830
+
831
+ if embed_tokens is not None:
832
+ self.embed_tokens = embed_tokens
833
+ else:
834
+ self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
835
+
836
+ if config.add_type_embedding:
837
+ self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
838
+ else:
839
+ self.type_embedding = None
840
+
841
+ if config.resnet_type == "resnet18":
842
+ self.embed_images = ResNet([2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
843
+ elif config.resnet_type == "resnet34":
844
+ self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
845
+ elif config.resnet_type == "resnet50":
846
+ self.embed_images = ResNet([3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
847
+ elif config.resnet_type == "resnet101":
848
+ self.embed_images = ResNet([3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
849
+ elif config.resnet_type == "resnet152":
850
+ self.embed_images = ResNet([3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
851
+ else:
852
+ raise NotImplementedError
853
+ self.image_proj = Linear(1024, embed_dim)
854
+
855
+ if config.resnet_model_path:
856
+ resnet_state_dict = torch.load(config.resnet_model_path)
857
+ self.embed_images.load_state_dict(resnet_state_dict)
858
+ if config.patch_layernorm_embedding:
859
+ self.patch_layernorm_embedding = LayerNorm(embed_dim)
860
+ else:
861
+ self.patch_layernorm_embedding = None
862
+
863
+ self.embed_positions = Embedding(self.max_source_positions + 2, embed_dim)
864
+ self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, embed_dim)
865
+ self.pos_ln = LayerNorm(embed_dim)
866
+ self.image_pos_ln = LayerNorm(embed_dim)
867
+ self.pos_scaling = float(embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5
868
+ self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
869
+ self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
870
+
871
+ if self.encoder_layerdrop > 0.0:
872
+ self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
873
+ else:
874
+ self.layers = nn.ModuleList([])
875
+
876
+ dpr = [x.item() for x in torch.linspace(0, config.encoder_drop_path_rate, config.encoder_layers)]
877
+ self.layers.extend(
878
+ [TiOEncoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.encoder_layers)]
879
+ )
880
+ self.num_layers = len(self.layers)
881
+
882
+ if config.encoder_normalize_before:
883
+ self.layer_norm = LayerNorm(embed_dim)
884
+ else:
885
+ self.layer_norm = None
886
+
887
+ self.token_bucket_size = config.token_bucket_size
888
+ token_num_rel_dis = 2 * config.token_bucket_size - 1
889
+ token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
890
+ self.token_rel_pos_table_list = nn.ModuleList(
891
+ [Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in
892
+ range(config.encoder_layers)]
893
+ )
894
+
895
+ self.image_bucket_size = config.image_bucket_size
896
+ image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3
897
+ image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis)
898
+ self.image_rel_pos_table_list = nn.ModuleList(
899
+ [Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True) for _ in
900
+ range(config.encoder_layers)]
901
+ )
902
+
903
+ if config.layernorm_embedding:
904
+ self.layernorm_embedding = LayerNorm(embed_dim)
905
+ else:
906
+ self.layernorm_embedding = None
907
+
908
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
909
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
910
+ self.entangle_position_embedding = config.entangle_position_embedding
911
+
912
+ self.gradient_checkpointing = False
913
+ # Initialize weights and apply final processing
914
+ self.post_init()
915
+
916
+ def get_input_embeddings(self):
917
+ r"""
918
+ Get the embedding weight.
919
+ """
920
+ return self.embed_tokens
921
+
922
+ def set_input_embeddings(self, value):
923
+ r"""
924
+ Set the weight of embedding with the given tensor.
925
+ """
926
+ self.embed_tokens = value
927
+
928
+ def get_rel_pos_bias(self, x, idx):
929
+ r"""
930
+ Get the relative positional bias of the text, for attention.
931
+ """
932
+
933
+ seq_len = x.size(1)
934
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
935
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
936
+ values = values.unsqueeze(0).expand(x.size(0), -1, -1, -1)
937
+ values = values.permute([0, 3, 1, 2])
938
+ return values.contiguous()
939
+
940
+ def get_image_rel_pos_bias(self, image_position_ids, idx):
941
+ r"""
942
+ Get the relative positional bias of the image, for attention.
943
+ """
944
+
945
+ bsz, seq_len = image_position_ids.shape
946
+ rp_bucket_size = self.image_rp_bucket.size(1)
947
+
948
+ rp_bucket = self.image_rp_bucket.unsqueeze(0).expand(
949
+ bsz, rp_bucket_size, rp_bucket_size
950
+ ).gather(1, image_position_ids[:, :, None].expand(bsz, seq_len, rp_bucket_size)
951
+ ).gather(2, image_position_ids[:, None, :].expand(bsz, seq_len, seq_len))
952
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
953
+ values = values.permute(0, 3, 1, 2)
954
+ return values
955
+
956
+ def get_patch_images_info(self, patch_images, sample_patch_num, device):
957
+ r"""
958
+ Get the basic information of the resized image.
959
+
960
+ Args:
961
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`): the resized image.
962
+ sample_patch_num (`int`):
963
+ the number of patches to sample. If it is equal to -1, no sampling will be performed.
964
+ device: GPU device.
965
+
966
+ Returns:
967
+ image_embed (`torch.FloatTensor` of shape `(bsz, h * w, hidden)`): the output of the visual encoder.
968
+ image_num_patches (`int`, equal to `h * w`): the number of patches.
969
+ image_padding_mask (`torch.BoolTensor` of shape `(bsz, h*w)`): image padding mask.
970
+ image_position_ids (`torch.LongTensor` of shape `(bsz, h*w)`): image position ids.
971
+ image_pos_embed (`torch.FloatTensor` of shape (bsz, h*w, hidden)): the positional embedding.
972
+ """
973
+
974
+ image_embed = self.embed_images(patch_images)
975
+ h, w = image_embed.shape[-2:]
976
+ image_num_patches = h * w
977
+ image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
978
+ image_position_idx = torch.arange(w).unsqueeze(0).expand(h, w) + \
979
+ torch.arange(h).unsqueeze(1) * self.image_bucket_size + 1
980
+ image_position_idx = image_position_idx.view(-1).to(device)
981
+ image_position_ids = image_position_idx[None, :].expand(patch_images.size(0), image_num_patches)
982
+
983
+ image_embed = image_embed.flatten(2).transpose(1, 2)
984
+ if sample_patch_num is not None:
985
+ patch_orders = [
986
+ random.sample(range(image_num_patches), k=sample_patch_num)
987
+ for _ in range(patch_images.size(0))
988
+ ]
989
+ patch_orders = torch.LongTensor(patch_orders).to(device)
990
+ image_embed = image_embed.gather(
991
+ 1, patch_orders.unsqueeze(2).expand(-1, -1, image_embed.size(2))
992
+ )
993
+ image_num_patches = sample_patch_num
994
+ image_padding_mask = image_padding_mask.gather(1, patch_orders)
995
+ image_position_ids = image_position_ids.gather(1, patch_orders)
996
+ image_pos_embed = self.embed_image_positions(image_position_ids)
997
+
998
+ return image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed
999
+
1000
+ def forward_embedding(
1001
+ self,
1002
+ input_ids,
1003
+ image_embed: Optional[torch.Tensor] = None,
1004
+ image_embed_2: Optional[torch.Tensor] = None,
1005
+ token_embedding: Optional[torch.Tensor] = None,
1006
+ pos_embed: Optional[torch.Tensor] = None,
1007
+ image_pos_embed: Optional[torch.Tensor] = None,
1008
+ image_pos_embed_2: Optional[torch.Tensor] = None
1009
+ ):
1010
+ r"""
1011
+ Generate embeddings of both the image and the text.
1012
+ Actually since TiO unifies both unimodal and multimodal data,
1013
+ image inputs are optional.
1014
+
1015
+ Args:
1016
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the tokens in the vocabulary.
1017
+ image_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*): image embeddings.
1018
+ image_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1019
+ image embeddings of the second image (if it exists).
1020
+ token_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`, *optional*):
1021
+ input token embeddings to replace the embeddings of input ids.
1022
+ image_pos_embed (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1023
+ positional embeddings of the image.
1024
+ image_pos_embed_2 (`torch.FloatTensor` of shape `(bsz, h*w, embed_dim)`, *optional*):
1025
+ positional embeddings of the second image.
1026
+
1027
+ Returns:
1028
+ x (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`): embeddings of the input.
1029
+ embed (`torch.FloatTensor` of shape `(bsz, h*w+seq_len, embed_dim)`):
1030
+ embeddings without adding positional and type embeddings.
1031
+ """
1032
+
1033
+ # embed tokens and positions
1034
+ if token_embedding is None:
1035
+ token_embedding = self.embed_tokens(input_ids)
1036
+ x = embed = self.embed_scale * token_embedding
1037
+ if self.entangle_position_embedding and pos_embed is not None:
1038
+ x += pos_embed
1039
+ if self.type_embedding is not None:
1040
+ x += self.type_embedding(input_ids.new_zeros(x.size()[:2]))
1041
+ if self.layernorm_embedding is not None:
1042
+ x = self.layernorm_embedding(x)
1043
+ x = self.dropout(x)
1044
+
1045
+ # embed raw images
1046
+ if image_embed is not None:
1047
+ image_embed = self.image_proj(image_embed)
1048
+ image_x = image_embed = self.embed_scale * image_embed
1049
+ if self.entangle_position_embedding and image_pos_embed is not None:
1050
+ image_x += image_pos_embed
1051
+ if self.type_embedding is not None:
1052
+ image_x += self.type_embedding(input_ids.new_ones(image_x.size()[:2]))
1053
+ if self.patch_layernorm_embedding is not None:
1054
+ image_x = self.patch_layernorm_embedding(image_x)
1055
+ image_x = self.dropout(image_x)
1056
+ x = torch.cat([image_x, x], dim=1)
1057
+ embed = torch.cat([image_embed, embed], dim=1)
1058
+
1059
+ if image_embed_2 is not None:
1060
+ assert self.type_embedding is not None
1061
+ image_embed_2 = self.image_proj(image_embed_2)
1062
+ image_x_2 = image_embed_2 = self.embed_scale * image_embed_2
1063
+ if self.entangle_position_embedding and image_pos_embed_2 is not None:
1064
+ image_x_2 += image_pos_embed_2
1065
+ if self.type_embedding is not None:
1066
+ image_x_2 += self.type_embedding(input_ids.new_full(image_x_2.size()[:2], fill_value=2))
1067
+ if self.patch_layernorm_embedding is not None:
1068
+ image_x_2 = self.patch_layernorm_embedding(image_x_2)
1069
+ image_x_2 = self.dropout(image_x_2)
1070
+ if self.quant_noise is not None:
1071
+ image_x_2 = self.quant_noise(image_x_2)
1072
+ x = torch.cat([image_x_2, x], dim=1)
1073
+ embed = torch.cat([image_embed_2, embed], dim=1)
1074
+
1075
+ return x, embed
1076
+
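The resulting encoder sequence is laid out as [second-image patches | first-image patches | text tokens], with type ids 2, 1 and 0 respectively (per the `type_embedding` calls above). A toy illustration of that layout, with made-up lengths:

import torch

num_patches, num_patches_2, seq_len = 4, 4, 6            # illustrative lengths
type_ids = torch.cat([
    torch.full((num_patches_2,), 2),                     # second image, if present
    torch.full((num_patches,), 1),                       # first image
    torch.zeros(seq_len, dtype=torch.long),              # text tokens
])
print(type_ids)   # tensor([2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0])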
1077
+ def reorder_encoder_out(self, encoder_out, new_order):
1078
+ """
1079
+ Reorder encoder output according to *new_order*.
1080
+
1081
+ Args:
1082
+ encoder_out: output from the ``forward()`` method
1083
+ new_order (LongTensor): desired order
1084
+
1085
+ Returns:
1086
+ *encoder_out* rearranged according to *new_order*
1087
+ """
1088
+
1089
+ if "last_hidden_state" not in encoder_out:
1090
+ new_encoder_out = None
1091
+ else:
1092
+ new_encoder_out = encoder_out["last_hidden_state"].index_select(0, new_order)
1093
+
1094
+ if "padding_mask" not in encoder_out:
1095
+ new_encoder_padding_mask = None
1096
+ else:
1097
+ new_encoder_padding_mask = encoder_out["padding_mask"].index_select(0, new_order)
1098
+
1099
+
1100
+ if "position_embedding" not in encoder_out:
1101
+ new_position_embeddings = None
1102
+ else:
1103
+ new_position_embeddings = encoder_out["position_embedding"].index_select(0, new_order)
1104
+
1105
+ if "hidden_states" not in encoder_out:
1106
+ new_encoder_states = None
1107
+ else:
1108
+ encoder_states = encoder_out["hidden_states"]
1109
+ new_encoder_states = ()
1110
+ if len(encoder_states) > 0:
1111
+ for idx, state in enumerate(encoder_states):
1112
+ new_encoder_states += (state.index_select(0, new_order),)
1113
+
1114
+ if "attentions" not in encoder_out:
1115
+ attentions = None
1116
+ else:
1117
+ attentions = encoder_out["attentions"]
1118
+
1119
+ return TiOEncoderOutput(
1120
+ last_hidden_state=new_encoder_out,
1121
+ padding_mask=new_encoder_padding_mask,
1122
+ hidden_states=new_encoder_states,
1123
+ attentions=attentions,
1124
+ position_embedding=new_position_embeddings
1125
+ )
1126
+
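A minimal sketch of what the `index_select(0, new_order)` calls above do during beam-search reordering (values made up for illustration):

import torch

last_hidden_state = torch.tensor([[1.0], [2.0], [3.0]])  # pretend encoder states for a batch of 3
new_order = torch.tensor([0, 0, 1])                      # beam search keeps item 0 twice and drops item 2
print(last_hidden_state.index_select(0, new_order))      # tensor([[1.], [1.], [2.]])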
1127
+ def forward(
1128
+ self,
1129
+ input_ids=None,
1130
+ patch_images: Optional[torch.Tensor] = None,
1131
+ patch_images_2: Optional[torch.Tensor] = None,
1132
+ patch_masks: Optional[torch.Tensor] = None,
1133
+ output_attentions: bool = False,
1134
+ output_hidden_states: bool = False,
1135
+ token_embeddings: Optional[torch.Tensor] = None,
1136
+ sample_patch_num: Optional[int] = None,
1137
+ ):
1138
+ r"""
1139
+ Args:
1140
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
1141
+ indices of input sequence tokens in the vocabulary, and padding will be ignored by default;
1142
+
1143
+ indices can be obtained using [`~TiOTokenizer`].
1144
+
1145
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1146
+ the resized images, which are transformed by the default operations.
1147
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1148
+ the second (if it exists) image.
1149
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
1150
+ output_attentions (`bool`): whether to return all attention weights.
1151
+ output_hidden_states (`bool`): whether to return all hidden states.
1152
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
1153
+ sample_patch_num (`int`): the number of patches to sample.
1154
+
1155
+ Returns:
1156
+ [`TiOEncoderOutput`]:
1157
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1158
+ the states of the last layer.
1159
+ padding_mask (`torch.BoolTensor` of shape `(bsz, seq_len)`):
1160
+ the padding mask of the source context.
1161
+ hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1162
+ the states of all layers including the embeddings.
1163
+ attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`):
1164
+ the attention weights of all layers.
1165
+ position_embedding (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1166
+ positional embeddings of the input image and tokens.
1167
+ """
1168
+
1169
+ image_embed = None
1170
+ image_embed_2 = None
1171
+ image_pos_embed = None
1172
+ image_pos_embed_2 = None
1173
+ if patch_images is not None:
1174
+ image_embed, image_num_patches, image_padding_mask, image_position_ids, image_pos_embed = \
1175
+ self.get_patch_images_info(patch_images, sample_patch_num, input_ids.device)
1176
+ # image_padding_mask[~patch_masks] = True  # temporarily disabled to work around a shape-mismatch bug
1177
+ if patch_images_2 is not None:
1178
+ image_embed_2, image_num_patches_2, image_padding_mask_2, image_position_ids_2, image_pos_embed_2 = \
1179
+ self.get_patch_images_info(patch_images_2, sample_patch_num, input_ids.device)
1180
+ image_padding_mask_2[~patch_masks] = True
1181
+
1182
+ encoder_padding_mask = input_ids.eq(self.padding_idx)
1183
+ if patch_images is not None:
1184
+ encoder_padding_mask = torch.cat([image_padding_mask, encoder_padding_mask], dim=1)
1185
+ if patch_images_2 is not None:
1186
+ encoder_padding_mask = torch.cat([image_padding_mask_2, encoder_padding_mask], dim=1)
1187
+ has_pads = encoder_padding_mask.any()
1188
+
1189
+ pos_embed = self.embed_positions(new_arange(input_ids))
1190
+ x, encoder_embedding = self.forward_embedding(
1191
+ input_ids, image_embed, image_embed_2, token_embeddings,
1192
+ pos_embed, image_pos_embed, image_pos_embed_2
1193
+ )
1194
+
1195
+ # account for padding while computing the representation
1196
+ if has_pads:
1197
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
1198
+
1199
+ pos_embed = self.pos_ln(pos_embed)
1200
+ if patch_images is not None:
1201
+ image_pos_embed = self.image_pos_ln(image_pos_embed)
1202
+ pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
1203
+ if patch_images_2 is not None:
1204
+ image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
1205
+ pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
1206
+
1207
+ pos_q = self.pos_q_linear(pos_embed).view(
1208
+ x.size(0), x.size(1), self.num_attention_heads, -1
1209
+ ).transpose(1, 2) * self.pos_scaling
1210
+ pos_k = self.pos_k_linear(pos_embed).view(
1211
+ x.size(0), x.size(1), self.num_attention_heads, -1
1212
+ ).transpose(1, 2)
1213
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
1214
+
1215
+ # expand attention_mask
1216
+ if has_pads:
1217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1218
+ attention_mask = _expand_mask(~encoder_padding_mask, dtype=x.dtype)
1219
+
1220
+ encoder_states = () if output_hidden_states else None
1221
+ all_attentions = () if output_attentions else None
1222
+
1223
+ # encoder layers
1224
+ for idx, layer in enumerate(self.layers):
1225
+ if output_hidden_states:
1226
+ encoder_states += (x,)
1227
+ self_attn_bias = abs_pos_bias.clone()
1228
+ self_attn_bias[:, :, -input_ids.size(1):, -input_ids.size(1):] += self.get_rel_pos_bias(input_ids, idx)
1229
+ if patch_images_2 is not None:
1230
+ self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
1231
+ self.get_image_rel_pos_bias(image_position_ids_2, idx)
1232
+ self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + image_num_patches,
1233
+ image_num_patches_2:image_num_patches_2 + image_num_patches] += \
1234
+ self.get_image_rel_pos_bias(image_position_ids, idx)
1235
+ elif patch_images is not None:
1236
+ self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
1237
+ self.get_image_rel_pos_bias(image_position_ids, idx)
1238
+ self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))
1239
+
1240
+ hidden_outputs = layer(x, attention_mask if has_pads else None, attn_bias=self_attn_bias, output_attentions=output_attentions)
1241
+ x = hidden_outputs[0]
1242
+
1243
+ if output_attentions:
1244
+ attention = hidden_outputs[1]
1245
+ all_attentions = all_attentions + (attention,)
1246
+
1247
+ if output_hidden_states:
1248
+ encoder_states += (x,)
1249
+
1250
+ if self.layer_norm is not None:
1251
+ x = self.layer_norm(x)
1252
+
1253
+ return TiOEncoderOutput(
1254
+ last_hidden_state=x,
1255
+ padding_mask=encoder_padding_mask,
1256
+ hidden_states=encoder_states,
1257
+ attentions=all_attentions,
1258
+ position_embedding=pos_embed,
1259
+ )
1260
+
1261
+
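A shape-only sketch of the absolute position bias computed in the encoder forward pass above; the pos_q/pos_k linear layers and the scaling factor are omitted, and the sizes are illustrative:

import torch

bsz, seq_len, num_heads, head_dim = 2, 5, 16, 80                # illustrative sizes
pos_embed = torch.randn(bsz, seq_len, num_heads * head_dim)
pos_q = pos_embed.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)   # (bsz, heads, seq, head_dim)
pos_k = pos_embed.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))       # (bsz, heads, seq, seq): one bias matrix per head
print(abs_pos_bias.shape)                                       # torch.Size([2, 16, 5, 5])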
1262
+ class TiODecoder(TiOPreTrainedModel):
1263
+ r"""
1264
+ TiO decoder consisting of layers of [`TiODecoderLayer`]
1265
+
1266
+ Args:
1267
+ config: TiOConfig
1268
+ embed_tokens (`nn.Embedding`, *optional*): output embedding
1269
+ """
1270
+
1271
+ def __init__(self, config: TiOConfig, embed_tokens: Optional[nn.Embedding] = None, output_projection=None):
1272
+ super().__init__(config)
1273
+ self.dropout = nn.Dropout(config.dropout)
1274
+ self.decoder_layerdrop = config.decoder_layerdrop
1275
+ self.padding_idx = config.pad_token_id
1276
+ self.max_target_positions = config.max_position_embeddings
1277
+ self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
1278
+
1279
+ self._future_mask = torch.empty(0)
1280
+ self.share_input_output_embed = config.share_decoder_input_output_embed
1281
+ self.num_attention_heads = config.decoder_attention_heads
1282
+
1283
+ if embed_tokens is not None:
1284
+ self.embed_tokens = embed_tokens
1285
+ else:
1286
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
1287
+
1288
+ self.embed_dim = config.d_model
1289
+ self.output_embed_dim = config.d_model
1290
+
1291
+ self.layers = nn.ModuleList([TiODecoderLayer(config) for _ in range(config.decoder_layers)])
1292
+ if config.layernorm_embedding:
1293
+ self.layernorm_embedding = LayerNorm(self.embed_dim)
1294
+ else:
1295
+ self.layernorm_embedding = None
1296
+
1297
+ self.window_size = config.code_image_size // 8
1298
+
1299
+ self.embed_positions = Embedding(self.max_target_positions + 2, self.embed_dim)
1300
+ self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, self.embed_dim)
1301
+ self.pos_ln = LayerNorm(self.embed_dim)
1302
+ self.image_pos_ln = LayerNorm(self.embed_dim)
1303
+ self.pos_scaling = float(self.embed_dim / self.num_attention_heads * config.attn_scale_factor) ** -0.5
1304
+ self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
1305
+ self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
1306
+ self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
1307
+ self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
1308
+
1309
+ if config.code_layernorm_embedding:
1310
+ self.code_layernorm_embedding = LayerNorm(self.embed_dim)
1311
+ else:
1312
+ self.code_layernorm_embedding = None
1313
+
1314
+ if self.decoder_layerdrop > 0.0:
1315
+ self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
1316
+ else:
1317
+ self.layers = nn.ModuleList([])
1318
+
1319
+ dpr = [x.item() for x in torch.linspace(0, config.decoder_drop_path_rate, config.decoder_layers)]
1320
+ self.layers.extend([TiODecoderLayer(config, drop_path_rate=dpr[i]) for i in range(config.decoder_layers)])
1321
+ self.num_layers = len(self.layers)
1322
+
1323
+ if config.decoder_normalize_before:
1324
+ self.layer_norm = LayerNorm(self.embed_dim)
1325
+ else:
1326
+ self.layer_norm = None
1327
+
1328
+ self.adaptive_softmax = None
1329
+ self.output_projection = output_projection
1330
+ if self.output_projection is None:
1331
+ self.build_output_projection(config)
1332
+
1333
+ self.token_bucket_size = config.token_bucket_size
1334
+ token_num_rel_dis = 2 * config.token_bucket_size - 1
1335
+ token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
1336
+ self.token_rel_pos_table_list = nn.ModuleList(
1337
+ [
1338
+ Embedding(token_num_rel_dis, self.num_attention_heads, zero_init=True)
1339
+ for _ in range(config.decoder_layers)
1340
+ ]
1341
+ )
1342
+
1343
+ self.image_bucket_size = config.image_bucket_size
1344
+ image_num_rel_dis = (2 * config.image_bucket_size - 1) * (2 * config.image_bucket_size - 1) + 3
1345
+ image_rp_bucket = make_image_bucket_position(config.image_bucket_size, image_num_rel_dis)
1346
+ image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
1347
+ torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1
1348
+ image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
1349
+ image_position_idx = torch.cat([image_position_idx, torch.tensor([1024] * 768)])
1350
+ self.image_rel_pos_table_list = nn.ModuleList(
1351
+ [
1352
+ Embedding(image_num_rel_dis, self.num_attention_heads, zero_init=True)
1353
+ for _ in range(config.decoder_layers)
1354
+ ]
1355
+ )
1356
+
1357
+ self.register_buffer("token_rp_bucket", token_rp_bucket)
1358
+ self.register_buffer("image_rp_bucket", image_rp_bucket)
1359
+ self.register_buffer("image_position_idx", image_position_idx)
1360
+ self.entangle_position_embedding = config.entangle_position_embedding
1361
+
1362
+ self.gradient_checkpointing = False
1363
+ # Initialize weights and apply final processing
1364
+ self.post_init()
1365
+
1366
+ def build_output_projection(self, config):
1367
+ if self.share_input_output_embed:
1368
+ self.output_projection = nn.Linear(
1369
+ self.embed_tokens.weight.shape[1],
1370
+ self.embed_tokens.weight.shape[0],
1371
+ bias=False,
1372
+ )
1373
+ self.output_projection.weight = self.embed_tokens.weight
1374
+ else:
1375
+ self.output_projection = nn.Linear(
1376
+ self.output_embed_dim, config.vocab_size, bias=False
1377
+ )
1378
+ nn.init.normal_(self.output_projection.weight, mean=0, std=self.output_embed_dim**-0.5)
1379
+
1380
+ def get_rel_pos_bias(self, x, idx):
1381
+ r"""
1382
+ Get the relative positional bias of the text, for attention.
1383
+ """
1384
+
1385
+ seq_len = x.size(1)
1386
+ rp_bucket = self.token_rp_bucket[:seq_len, :seq_len]
1387
+ values = F.embedding(rp_bucket, self.token_rel_pos_table_list[idx].weight)
1388
+ values = values.permute([2, 0, 1])
1389
+ return values.contiguous()
1390
+
1391
+ def get_image_rel_pos_bias(self, x, idx):
1392
+ r"""
1393
+ Get the relative positional bias of the image, for attention.
1394
+ """
1395
+
1396
+ seq_len = x.size(1)
1397
+ image_position_idx = self.image_position_idx[:seq_len]
1398
+ rp_bucket = self.image_rp_bucket[image_position_idx][:, image_position_idx]
1399
+ values = F.embedding(rp_bucket, self.image_rel_pos_table_list[idx].weight)
1400
+ values = values.permute(2, 0, 1)
1401
+ return values
1402
+
1403
+ def get_pos_info(self, tgt_pos_embed, src_pos_embed=None, use_image=False):
1404
+ r"""
1405
+ Get the positional information.
1406
+
1407
+ Args:
1408
+ tgt_pos_embed (`torch.FloatTensor` of shape `(bsz, tgt_len, embed_dim)`):
1409
+ the target-side positional embeddings.
1410
+ src_pos_embed (`torch.FloatTensor` of shape `(bsz, src_len, embed_dim)`, *optional*):
1411
+ the source-side positional embeddings.
1412
+ use_image (`bool`): whether to use image.
1413
+
1414
+ Returns:
1415
+ abs_pos_bias (`torch.FloatTensor` of shape `(bsz, num_heads, tgt_len, src_len)`):
1416
+ absolute positional bias for attention.
1417
+ """
1418
+
1419
+ batch_size = tgt_pos_embed.size(0)
1420
+ tgt_len = tgt_pos_embed.size(1)
1421
+ tgt_pos_embed = self.image_pos_ln(tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
1422
+
1423
+ if src_pos_embed is not None:
1424
+ src_len = src_pos_embed.size(1)
1425
+ pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
1426
+ batch_size, tgt_len, self.num_attention_heads, -1
1427
+ ).transpose(1, 2) * self.pos_scaling
1428
+ pos_k = self.cross_pos_k_linear(src_pos_embed).view(
1429
+ batch_size, src_len, self.num_attention_heads, -1
1430
+ ).transpose(1, 2)
1431
+ else:
1432
+ src_len = tgt_pos_embed.size(1)
1433
+ pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
1434
+ batch_size, tgt_len, self.num_attention_heads, -1
1435
+ ).transpose(1, 2) * self.pos_scaling
1436
+ pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
1437
+ batch_size, src_len, self.num_attention_heads, -1
1438
+ ).transpose(1, 2)
1439
+ abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
1440
+
1441
+ return abs_pos_bias
1442
+
1443
+ def get_input_embeddings(self):
1444
+ r"""
1445
+ Get the input embeddings
1446
+ """
1447
+ return self.embed_tokens
1448
+
1449
+ def set_input_embeddings(self, value):
1450
+ r"""
1451
+ Set the weights of the embeddings with the given tensor.
1452
+ """
1453
+ self.embed_tokens = value
1454
+
1455
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, dtype, past_key_values_length):
1456
+ r"""
1457
+ Create causal mask for unidirectional decoding.
1458
+ [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1459
+ """
1460
+ combined_attention_mask = None
1461
+ if input_shape[-1] > 1:
1462
+ combined_attention_mask = _make_causal_mask(
1463
+ input_shape, dtype, past_key_values_length=past_key_values_length
1464
+ ).to(self.device)
1465
+
1466
+ if attention_mask is not None:
1467
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1468
+ expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1])
1469
+ combined_attention_mask = (
1470
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
1471
+ )
1472
+
1473
+ return combined_attention_mask
1474
+
1475
+ def max_positions(self):
1476
+ """Maximum output length supported by the decoder."""
1477
+ if self.embed_positions is None:
1478
+ return self.max_target_positions
1479
+ return self.max_target_positions
1480
+
1481
+ def get_normalized_probs(
1482
+ self,
1483
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1484
+ log_probs: bool,
1485
+ sample: Optional[Dict[str, Tensor]] = None,
1486
+ ):
1487
+ """Get normalized probabilities (or log probs) from a net's output."""
1488
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
1489
+
1490
+ def get_normalized_probs_scriptable(
1491
+ self,
1492
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1493
+ log_probs: bool,
1494
+ sample: Optional[Dict[str, Tensor]] = None,
1495
+ ):
1496
+ """Get normalized probabilities (or log probs) from a net's output."""
1497
+
1498
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
1499
+ if sample is not None:
1500
+ assert "target" in sample
1501
+ target = sample["target"]
1502
+ else:
1503
+ target = None
1504
+ out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
1505
+ return out.exp_() if not log_probs else out
1506
+
1507
+ logits = net_output[0]
1508
+ if log_probs:
1509
+ return F.log_softmax(logits, dim=-1)
1510
+ else:
1511
+ return F.softmax(logits, dim=-1)
1512
+
1513
+ def reorder_incremental_state_scripting(
1514
+ self,
1515
+ # incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
1516
+ past_key_values: Optional[torch.Tensor],
1517
+ new_order: Tensor,
1518
+ ):
1519
+ """Main entry point for reordering the incremental state.
1520
+
1521
+ Due to limitations in TorchScript, we call this function in
1522
+ :class:`fairseq.sequence_generator.SequenceGenerator` instead of
1523
+ calling :func:`reorder_incremental_state` directly.
1524
+ """
1525
+ input_buffer = past_key_values
1526
+ new_past_key_values = []
1527
+ if input_buffer is not None:
1528
+ for input_buffer_k in input_buffer:
1529
+ new_input_buffer_k = []
1530
+ for input in input_buffer_k:
1531
+ if input is None:
1532
+ input = None
1533
+ else:
1534
+ input = input.index_select(0, new_order)
1535
+ new_input_buffer_k.append(input)
1536
+ new_past_key_values.append(new_input_buffer_k)
1537
+ return new_past_key_values
1538
+
1539
+ def forward(
1540
+ self,
1541
+ input_ids: torch.Tensor = None,
1542
+ attention_mask: torch.Tensor = None,
1543
+ encoder_hidden_states: torch.Tensor = None,
1544
+ encoder_attention_mask: torch.Tensor = None,
1545
+ code_masks: Optional[torch.Tensor] = None,
1546
+ src_pos_embed: torch.Tensor = None,
1547
+ past_key_values: Optional[torch.Tensor] = None,
1548
+ use_cache: bool = False,
1549
+ output_attentions: bool = False,
1550
+ output_hidden_states: bool = False,
1551
+ ):
1552
+ r"""
1553
+ Args:
1554
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
1555
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): mask to avoid attention on padding tokens.
1556
+ encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the last hidden state of the encoder.
1557
+ encoder_attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): the padding mask of the source side.
1558
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
1559
+ src_pos_embed (`torch.FloatTensor` of shape `(bsz, seq_len, hidden)`): the positional embeddings of the source side.
1560
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
1561
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1562
+ shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of
1563
+ shape `(bsz, num_heads, src_len, head_size)`.
1564
+ use_cache (`bool`): whether to use cache for faster inference.
1565
+ output_attentions (`bool`): whether to output attention weights.
1566
+ output_hidden_states (`bool`): whether to output hidden states.
1567
+
1568
+ Returns:
1569
+ BaseModelOutputWithPastAndCrossAttentions or a plain tuple:
1570
+ last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, vocab_size)`): the decoder output after the output projection (consumed as logits by `TiOModel`).
1571
+ past_key_values (`tuple(tuple(torch.FloatTensor))`): past keys and values for faster inference.
1572
+ hidden_states (`tuple(torch.FloatTensor)`): hidden states of all layers.
1573
+ attentions (`tuple(torch.FloatTensor)`): self attention weights of all layers.
1574
+ cross_attentions (`tuple(torch.FloatTensor)`): cross attention weights of all layers.
1575
+ """
1576
+
1577
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1578
+ output_hidden_states = (
1579
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1580
+ )
1581
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1582
+
1583
+ if past_key_values is not None and len(past_key_values)>0:
1584
+ size = past_key_values[0][0].size()
1585
+ bsz, tgt_len = size[0], size[-2] + 1
1586
+ token_position_idx = torch.arange(tgt_len, device=input_ids.device).expand([bsz, tgt_len]).contiguous()
1587
+ else:
1588
+ bsz, tgt_len = input_ids.shape
1589
+ token_position_idx = new_arange(input_ids)
1590
+ tgt_pos_embed = self.embed_positions(token_position_idx)
1591
+ if code_masks is not None and torch.any(code_masks):
1592
+ image_position_idx = self.image_position_idx[:input_ids.size(1)].unsqueeze(0).expand(bsz, tgt_len)
1593
+ tgt_pos_embed[code_masks] = self.embed_image_positions(image_position_idx)[code_masks]
1594
+
1595
+ # self attn position bias
1596
+ self_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=False)
1597
+ if code_masks is not None and torch.any(code_masks):
1598
+ self_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, use_image=True)
1599
+ self_abs_pos_bias[code_masks] = self_image_abs_pos_bias[code_masks]
1600
+ # cross attn position bias
1601
+ cross_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed)
1602
+ if code_masks is not None and torch.any(code_masks):
1603
+ cross_image_abs_pos_bias = self.get_pos_info(tgt_pos_embed, src_pos_embed=src_pos_embed, use_image=True)
1604
+ cross_abs_pos_bias[code_masks] = cross_image_abs_pos_bias[code_masks]
1605
+ cross_abs_pos_bias = cross_abs_pos_bias.reshape(-1, *cross_abs_pos_bias.size()[-2:])
1606
+
1607
+ all_prev_output_tokens = input_ids.clone()
1608
+ if past_key_values is not None and len(past_key_values)>0:
1609
+ input_ids = input_ids[:, -1:]
1610
+ cross_abs_pos_bias = cross_abs_pos_bias[:, -1:, :]
1611
+ tgt_pos_embed = tgt_pos_embed[:, -1:, :]
1612
+
1613
+ # embed tokens and positions
1614
+ x = self.embed_scale * self.embed_tokens(input_ids)
1615
+
1616
+
1617
+ if self.entangle_position_embedding and not self.disable_entangle:
1618
+ x += tgt_pos_embed
1619
+
1620
+ if self.layernorm_embedding is not None:
1621
+ if code_masks is None or not code_masks.any() or not self.code_layernorm_embedding:
1622
+ x = self.layernorm_embedding(x)
1623
+ elif code_masks is not None and code_masks.all():
1624
+ x = self.code_layernorm_embedding(x)
1625
+ else:
1626
+ x[~code_masks] = self.layernorm_embedding(x[~code_masks])
1627
+ x[code_masks] = self.code_layernorm_embedding(x[code_masks])
1628
+
1629
+ hidden_states = self.dropout(x)
1630
+
1631
+ # past_key_values_length
1632
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None and len(past_key_values)>0 else 0
1633
+
1634
+ shape, dtype = input_ids.shape, hidden_states.dtype
1635
+ attention_mask = self._prepare_decoder_attention_mask(attention_mask, shape, dtype, past_key_values_length)
1636
+
1637
+ # decoder layers
1638
+ all_hidden_states = () if output_hidden_states else None
1639
+ all_self_attns = () if output_attentions else None
1640
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1641
+ next_decoder_cache = () if use_cache else None
1642
+
1643
+ # decoder layers
1644
+ for idx, layer in enumerate(self.layers):
1645
+ # add hidden states from the last decoder layer
1646
+ if output_hidden_states:
1647
+ all_hidden_states += (hidden_states,)
1648
+
1649
+ past_key_value = past_key_values[idx] if past_key_values is not None and len(past_key_values)>0 else None
1650
+
1651
+ self_attn_bias = self_abs_pos_bias.clone()
1652
+ if code_masks is None or not code_masks.any():
1653
+ self_attn_bias += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1654
+ elif code_masks is not None and code_masks.all():
1655
+ self_attn_bias += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1656
+ else:
1657
+ self_attn_bias[~code_masks] += self.get_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1658
+ self_attn_bias[code_masks] += self.get_image_rel_pos_bias(all_prev_output_tokens, idx).unsqueeze(0)
1659
+ self_attn_bias = self_attn_bias.reshape(-1, *self_attn_bias.size()[-2:])
1660
+ if past_key_value is not None and len(past_key_values)>0 :
1661
+ self_attn_bias = self_attn_bias[:, -1:, :]
1662
+
1663
+ layer_outputs = layer(
1664
+ hidden_states,
1665
+ attention_mask=attention_mask,
1666
+ encoder_hidden_states=encoder_hidden_states,
1667
+ encoder_attention_mask=encoder_attention_mask,
1668
+ past_key_value=past_key_value,
1669
+ output_attentions=output_attentions,
1670
+ use_cache=use_cache,
1671
+ self_attn_bias=self_attn_bias,
1672
+ cross_attn_bias=cross_abs_pos_bias,
1673
+ )
1674
+ hidden_states = layer_outputs[0]
1675
+
1676
+ if use_cache:
1677
+ next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1678
+
1679
+ if output_attentions:
1680
+ all_self_attns += (layer_outputs[1],)
1681
+
1682
+ if encoder_hidden_states is not None:
1683
+ all_cross_attentions += (layer_outputs[2],)
1684
+
1685
+ # add hidden states from the last decoder layer
1686
+ if output_hidden_states:
1687
+ all_hidden_states += (hidden_states,)
1688
+
1689
+ next_cache = next_decoder_cache if use_cache else None
1690
+
1691
+ if self.layer_norm is not None:
1692
+ hidden_states = self.layer_norm(hidden_states)
1693
+
1694
+ if self.output_projection is not None:
1695
+ hidden_states = self.output_projection(hidden_states)
1696
+
1697
+ return BaseModelOutputWithPastAndCrossAttentions(
1698
+ last_hidden_state=hidden_states,
1699
+ past_key_values=next_cache,
1700
+ hidden_states=all_hidden_states,
1701
+ attentions=all_self_attns,
1702
+ cross_attentions=all_cross_attentions,
1703
+ )
1704
+
1705
+
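Note that once `past_key_values` is populated the decoder embeds only the newest token (`input_ids[:, -1:]`). A hand-rolled greedy loop over the full seq2seq model defined below would look roughly like this sketch; `model` and `enc_out` (a `TiOEncoderOutput`) are assumed to already exist, and in practice `generate()` does all of this for you:

import torch

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])   # assumes `model` exists, batch size 1
past = None
for _ in range(20):
    out = model(decoder_input_ids=decoder_input_ids, encoder_outputs=enc_out,
                past_key_values=past, use_cache=True)
    next_token = out.logits[:, -1].argmax(-1, keepdim=True)   # greedy pick from the vocabulary logits
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    past = out.past_key_values                                # cache grows; only the new token is re-embedded
    if next_token.item() == model.config.eos_token_id:
        break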
1706
+ # @add_start_docstrings(
1707
+ # "The bare TiO Model outputting raw hidden-states without any specific head on top.",
1708
+ # TiO_START_DOCSTRING,
1709
+ # )
1710
+ class TiOModel(TiOPreTrainedModel):
1711
+ r"""
1712
+ The TiO model built with an encoder and a decoder only, without any classification head.
1713
+
1714
+ Args:
1715
+ config (TiOConfig): TiO configuration.
1716
+ """
1717
+
1718
+ def __init__(self, config: TiOConfig, **kwargs):
1719
+ super().__init__(config)
1720
+ self.disable_entangle = kwargs.get('disable_entangle', False)
1721
+
1722
+ self.padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1723
+ shared = nn.Embedding(vocab_size, config.d_model, self.padding_idx)
1724
+
1725
+ self.encoder = TiOEncoder(config, shared)
1726
+ self.decoder = TiODecoder(config, shared)
1727
+
1728
+ # Initialize weights and apply final processing
1729
+ self.post_init()
1730
+
1731
+ def get_input_embeddings(self):
1732
+ r"""
1733
+ Retrieve input embeddings.
1734
+ """
1735
+ return self.encoder.get_input_embeddings()
1736
+
1737
+ def set_input_embeddings(self, value):
1738
+ r"""
1739
+ Set values for input embeddings
1740
+ """
1741
+ shared = value
1742
+ self.encoder.embed_tokens = shared
1743
+ self.decoder.embed_tokens = shared
1744
+
1745
+ def get_encoder(self):
1746
+ r"""
1747
+ Retrieve the encoder
1748
+ """
1749
+ return self.encoder
1750
+
1751
+ def get_decoder(self):
1752
+ r"""
1753
+ Retrieve the decoder
1754
+ """
1755
+ return self.decoder
1756
+
1757
+ # @add_start_docstrings_to_model_forward(TiO_INPUTS_DOCSTRING)
1758
+ # @add_code_sample_docstrings(
1759
+ # processor_class=_TOKENIZER_FOR_DOC,
1760
+ # checkpoint=_CHECKPOINT_FOR_DOC,
1761
+ # output_type=Seq2SeqModelOutput,
1762
+ # config_class=_CONFIG_FOR_DOC,
1763
+ # )
1764
+
1765
+ def max_decoder_positions(self):
1766
+ """Maximum length supported by the decoder."""
1767
+ return self.decoder.max_positions()
1768
+
1769
+ def get_normalized_probs(
1770
+ self,
1771
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1772
+ log_probs: bool,
1773
+ sample: Optional[Dict[str, Tensor]] = None,
1774
+ ):
1775
+ """Get normalized probabilities (or log probs) from a net's output."""
1776
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
1777
+
1778
+
1779
+ def get_normalized_probs_scriptable(
1780
+ self,
1781
+ net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
1782
+ log_probs: bool,
1783
+ sample: Optional[Dict[str, Tensor]] = None,
1784
+ ):
1785
+ """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel"""
1786
+ if hasattr(self, "decoder"):
1787
+ return self.decoder.get_normalized_probs(net_output, log_probs, sample)
1788
+ elif torch.is_tensor(net_output):
1789
+ # syntactic sugar for simple models which don't have a decoder
1790
+ # (e.g., the classification tutorial)
1791
+ logits = net_output.float()
1792
+ if log_probs:
1793
+ return F.log_softmax(logits, dim=-1)
1794
+ else:
1795
+ return F.softmax(logits, dim=-1)
1796
+ raise NotImplementedError
1797
+
1798
+ def forward(
1799
+ self,
1800
+ input_ids=None,
1801
+ patch_images=None,
1802
+ patch_images_2=None,
1803
+ patch_masks=None,
1804
+ token_embeddings=None,
1805
+ sample_patch_num=None,
1806
+ decoder_input_ids=None,
1807
+ code_masks=None,
1808
+ attention_mask=None,
1809
+ encoder_outputs=None,
1810
+ past_key_values=None,
1811
+ use_cache=False,
1812
+ output_attentions=False,
1813
+ output_hidden_states=False,
1814
+ labels=None,
1815
+ return_dict=False
1816
+ ):
1817
+ r"""
1818
+ Args:
1819
+ input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`):
1820
+ indices of input sequence tokens in the vocabulary, and padding will be ignored by default;
1821
+
1822
+ indices can be obtained using [`~TiOTokenizer`].
1823
+
1824
+ patch_images (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1825
+ the resized images, which are transformed by the default operations.
1826
+ patch_images_2 (`torch.FloatTensor` of shape `(bsz, 3, height, width)`):
1827
+ the second (if it exists) image.
1828
+ patch_masks (`torch.BoolTensor`): the patches to be masked.
1829
+ token_embeddings (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`): token embeddings.
1830
+ sample_patch_num (`int`): the number of patches to sample.
1831
+ decoder_input_ids (`torch.LongTensor` of shape `(bsz, seq_len)`): indices of the sequence in the vocabulary.
1832
+ code_masks (`torch.Tensor` of shape `(bsz, seq_len)`): masks only for code generation.
1833
+ attention_mask (`torch.Tensor` of shape `(bsz, seq_len)`): attention mask for decoding.
1834
+ encoder_outputs (`TiOEncoderOutput`):
1835
+ encoder outputs with hidden states, positional embeddings, and padding masks.
1836
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed):
1837
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1838
+ shape `(bsz, num_heads, tgt_len, head_size)`) and 2 additional tensors of
1839
+ shape `(bsz, num_heads, src_len, head_size)`.
1840
+ use_cache (`bool`): whether to use cache for faster inference.
1841
+ output_attentions (`bool`): whether to output attention weights.
1842
+ output_hidden_states (`bool`): whether to output hidden states.
1843
+ return_dict (`bool`): unused; accepted only for compatibility with generation.
1844
+
1845
+ Returns:
1846
+ Seq2SeqLMOutput:
1847
+ logits (`torch.FloatTensor` of shape `(bsz, seq_len, vocab_size)`): prediction scores over the vocabulary (the decoder output after the output projection).
1848
+ past_key_values (`tuple(tuple(torch.FloatTensor))`): past keys and values for faster inference.
1849
+ decoder_hidden_states (`tuple(torch.FloatTensor)`): the decoder hidden states of all layers.
1850
+ decoder_attentions (`tuple(torch.FloatTensor)`): the decoder self attention weights of all layers.
1851
+ cross_attentions (`tuple(torch.FloatTensor)`): cross attention weights of all layers.
1852
+ encoder_last_hidden_state (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1853
+ the encoder last hidden state.
1854
+ encoder_hidden_states (`torch.FloatTensor` of shape `(bsz, seq_len, embed_dim)`):
1855
+ the encoder states of all layers including the embeddings.
1856
+ encoder_attentions (`torch.FloatTensor` of shape `(bsz, num_heads, seq_len, seq_len)`):
1857
+ the encoder attention weights of all layers.
1858
+ """
1859
+
1860
+ output_attentions = output_attentions if output_attentions else self.config.output_attentions
1861
+ output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
1862
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1863
+
1864
+ if encoder_outputs is None:
1865
+ encoder_outputs = self.encoder(
1866
+ input_ids=input_ids,
1867
+ patch_images=patch_images,
1868
+ patch_images_2=patch_images_2,
1869
+ patch_masks=patch_masks,
1870
+ output_attentions=output_attentions,
1871
+ output_hidden_states=output_hidden_states,
1872
+ token_embeddings=token_embeddings,
1873
+ sample_patch_num=sample_patch_num,
1874
+ )
1875
+
1876
+ # if decoder_input_ids.eq(self.config.pad_token_id).any():
1877
+ # attention_mask = decoder_input_ids.eq(self.padding_idx)
1878
+
1879
+ encoder_hidden_states = encoder_outputs.last_hidden_state
1880
+ if past_key_values is not None and len(past_key_values)>0:
1881
+ encoder_attention_mask = _expand_mask(
1882
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids[:, -1:].shape[-1]
1883
+ )
1884
+ else:
1885
+ encoder_attention_mask = _expand_mask(
1886
+ ~encoder_outputs.padding_mask, encoder_hidden_states.dtype, decoder_input_ids.shape[-1]
1887
+ )
1888
+ src_pos_embed = encoder_outputs.position_embedding
1889
+
1890
+ decoder_outputs = self.decoder(
1891
+ input_ids=decoder_input_ids,
1892
+ attention_mask=attention_mask,
1893
+ encoder_hidden_states=encoder_hidden_states,
1894
+ encoder_attention_mask=encoder_attention_mask,
1895
+ code_masks=code_masks,
1896
+ src_pos_embed=src_pos_embed,
1897
+ past_key_values=past_key_values,
1898
+ use_cache=use_cache,
1899
+ output_attentions=output_attentions,
1900
+ output_hidden_states=output_hidden_states,
1901
+ )
1902
+
1903
+ return Seq2SeqLMOutput(
1904
+ logits=decoder_outputs.last_hidden_state,
1905
+ past_key_values=decoder_outputs.past_key_values,
1906
+ decoder_hidden_states=decoder_outputs.hidden_states,
1907
+ decoder_attentions=decoder_outputs.attentions,
1908
+ cross_attentions=decoder_outputs.cross_attentions,
1909
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1910
+ encoder_hidden_states=encoder_outputs.hidden_states,
1911
+ encoder_attentions=encoder_outputs.attentions,
1912
+ )
1913
+
1914
+ def prepare_inputs_for_generation(
1915
+ self,
1916
+ decoder_input_ids=None,
1917
+ past=None,
1918
+ attention_mask=None,
1919
+ code_masks=None,
1920
+ use_cache=False,
1921
+ encoder_outputs=None,
1922
+ **kwargs
1923
+ ):
1924
+ # if attention_mask is None:
1925
+ attention_mask = decoder_input_ids.new_ones(decoder_input_ids.shape)
1926
+
1927
+ # cut decoder_input_ids if past is used
1928
+ # if past is not None:
1929
+ # decoder_input_ids = decoder_input_ids[:, -1:]
1930
+
1931
+ return {
1932
+ "input_ids": None,
1933
+ "patch_images": None,
1934
+ "patch_images_2": None,
1935
+ "patch_masks": None,
1936
+ "token_embeddings": None,
1937
+ "sample_patch_num": None,
1938
+ "attention_mask": attention_mask,
1939
+ "encoder_outputs": encoder_outputs,
1940
+ "past_key_values": past,
1941
+ "decoder_input_ids": decoder_input_ids,
1942
+ "code_masks": code_masks,
1943
+ "use_cache": use_cache,
1944
+ }
1945
+
1946
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1947
+ return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1948
+
1949
+ def _prepare_encoder_decoder_kwargs_for_generation(
1950
+ self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
1951
+ ):
1952
+ # 1. get encoder
1953
+ encoder = self.get_encoder()
1954
+
1955
+ # 2. prepare encoder args and encoder kwargs from model kwargs
1956
+ irrelevant_prefix = ["decoder_", "cross_attn", "use_cache", "attention_mask"]
1957
+ encoder_kwargs = {
1958
+ argument: value
1959
+ for argument, value in model_kwargs.items()
1960
+ if not any(argument.startswith(p) for p in irrelevant_prefix)
1961
+ }
1962
+
1963
+ if encoder_kwargs.get("patch_masks") is None:
1964
+ encoder_kwargs["patch_masks"] = torch.ones((len(inputs_tensor), 1), dtype=torch.bool, device=inputs_tensor.device)
1965
+
1966
+ # 3. make sure that encoder returns `ModelOutput`
1967
+ model_input_name = model_input_name if model_input_name is not None else self.main_input_name
1968
+ encoder_kwargs[model_input_name] = inputs_tensor
1969
+ model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
1970
+ model_kwargs["attention_mask"] = None
1971
+
1972
+ return model_kwargs
1973
+
1974
+ @staticmethod
1975
+ def _reorder_cache(past, beam_idx):
1976
+ reordered_past = ()
1977
+ for layer_past in past:
1978
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1979
+ return reordered_past
1980
+
1981
+ @staticmethod
1982
+ def _expand_inputs_for_generation(
1983
+ input_ids: torch.LongTensor,
1984
+ expand_size: int = 1,
1985
+ is_encoder_decoder: bool = False,
1986
+ attention_mask: Optional[torch.LongTensor] = None,
1987
+ encoder_outputs: Optional[ModelOutput] = None,
1988
+ **model_kwargs,
1989
+ ):
1990
+ expanded_return_idx = (
1991
+ torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
1992
+ )
1993
+ input_ids = input_ids.index_select(0, expanded_return_idx)
1994
+
1995
+ if "token_type_ids" in model_kwargs:
1996
+ token_type_ids = model_kwargs["token_type_ids"]
1997
+ model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
1998
+
1999
+ if attention_mask is not None:
2000
+ model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
2001
+
2002
+ if is_encoder_decoder:
2003
+ if encoder_outputs is None:
2004
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
2005
+ encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
2006
+ 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
2007
+ )
2008
+ encoder_outputs["position_embedding"] = encoder_outputs.position_embedding.index_select(
2009
+ 0, expanded_return_idx.to(encoder_outputs.position_embedding.device)
2010
+ )
2011
+ encoder_outputs["padding_mask"] = encoder_outputs.padding_mask.index_select(
2012
+ 0, expanded_return_idx.to(encoder_outputs.padding_mask.device)
2013
+ )
2014
+ model_kwargs["encoder_outputs"] = encoder_outputs
2015
+ return input_ids, model_kwargs
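Putting the pieces together, a hypothetical end-to-end usage sketch. The repo id, image path and prompt are placeholders, loading assumes the custom code path is enabled with `trust_remote_code=True`, and the exact preprocessing the checkpoint expects should be checked against the processor and tokenizer configs below:

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer

repo = "jxu124/TiO"                                   # assumed repo id, adjust as needed
tokenizer = AutoTokenizer.from_pretrained(repo, use_fast=False)
image_processor = AutoImageProcessor.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True).eval()

image = Image.open("example.jpg").convert("RGB")      # placeholder image
patch_images = image_processor(image, return_tensors="pt").pixel_values
input_ids = tokenizer(" what does the image describe?", return_tensors="pt").input_ids  # placeholder prompt

with torch.no_grad():
    out = model.generate(input_ids, patch_images=patch_images, num_beams=3, max_new_tokens=32)
print(tokenizer.batch_decode(out, skip_special_tokens=True))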
preprocessor_config.json ADDED
@@ -0,0 +1,22 @@
1
+ {
2
+ "crop_pct": 0.875,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "ConvNextImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 3,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "shortest_edge": 512
21
+ }
22
+ }
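In plain terms this config resizes according to `size.shortest_edge = 512` (bicubic, `resample: 3`), rescales pixels by 1/255 and normalizes with mean/std 0.5, so the resulting values lie in [-1, 1]. A small check, assuming the file is loaded through `AutoImageProcessor` (repo id is a placeholder):

from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("jxu124/TiO")   # assumed repo id
dummy = Image.new("RGB", (640, 480), color=(255, 255, 255))    # dummy white image
batch = processor(dummy, return_tensors="pt")
print(batch.pixel_values.shape)                                # 3-channel tensor sized from shortest_edge=512
print(batch.pixel_values.min().item(), batch.pixel_values.max().item())   # 1.0, 1.0: (255/255 - 0.5) / 0.5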
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9521f3753abfe4f992a38139b2f6a794425f0a1f1a497ab3e23bb1423a58c0a1
3
+ size 4394547219
resnet.py ADDED
@@ -0,0 +1,268 @@
1
+ # coding=utf-8
2
+ # [Apache-2.0] Copied from https://github.com/OFA-Sys/OFA
3
+ # Copyright 2022 The OFA-Sys Team. All rights reserved.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
10
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
11
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
12
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
13
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
14
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
15
+ 'survival rate' as the argument.
16
+ """
17
+ if drop_prob == 0.0 or not training:
18
+ return x
19
+ keep_prob = 1 - drop_prob
20
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
21
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
22
+ random_tensor.floor_() # binarize
23
+ output = x.div(keep_prob) * random_tensor
24
+ return output
25
+
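The `div(keep_prob)` rescaling keeps the expected activation equal to the input, which is why no correction is needed at eval time. A quick sanity check of that expectation (illustrative only):

import torch

x = torch.ones(100000, 8)
keep_prob = 0.8
mask = (keep_prob + torch.rand(x.shape[0], 1)).floor_()   # 1 with probability keep_prob, else 0
out = x.div(keep_prob) * mask
print(out.mean().item())   # approximately 1.0, matching the un-dropped input in expectation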
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
36
+
37
+
38
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
39
+ """3x3 convolution with padding"""
40
+ return nn.Conv2d(
41
+ in_planes,
42
+ out_planes,
43
+ kernel_size=3,
44
+ stride=stride,
45
+ padding=dilation,
46
+ groups=groups,
47
+ bias=False,
48
+ dilation=dilation,
49
+ )
50
+
51
+
52
+ def conv1x1(in_planes, out_planes, stride=1):
53
+ """1x1 convolution"""
54
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
55
+
56
+
57
+ class BasicBlock(nn.Module):
58
+ expansion = 1
59
+
60
+ def __init__(
61
+ self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
62
+ ):
63
+ super(BasicBlock, self).__init__()
64
+ if norm_layer is None:
65
+ norm_layer = nn.BatchNorm2d
66
+ if groups != 1 or base_width != 64:
67
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
68
+ if dilation > 1:
69
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
70
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
71
+ self.conv1 = conv3x3(inplanes, planes, stride)
72
+ self.bn1 = norm_layer(planes)
73
+ self.relu = nn.ReLU(inplace=True)
74
+ self.conv2 = conv3x3(planes, planes)
75
+ self.bn2 = norm_layer(planes)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+
79
+ def forward(self, x):
80
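+ # Note: only Bottleneck blocks appear to be instantiated by ResNet below, so BasicBlock.forward should never run; the assert makes accidental use fail loudly.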
+ assert False
81
+ identity = x
82
+
83
+ out = self.conv1(x)
84
+ out = self.bn1(out)
85
+ out = self.relu(out)
86
+
87
+ out = self.conv2(out)
88
+ out = self.bn2(out)
89
+
90
+ if self.downsample is not None:
91
+ identity = self.downsample(x)
92
+
93
+ out += identity
94
+ out = self.relu(out)
95
+
96
+ return out
97
+
98
+
99
+ class Bottleneck(nn.Module):
100
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
101
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
102
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
103
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
104
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
105
+
106
+ expansion = 4
107
+
108
+ def __init__(
109
+ self,
110
+ inplanes,
111
+ planes,
112
+ stride=1,
113
+ downsample=None,
114
+ groups=1,
115
+ base_width=64,
116
+ dilation=1,
117
+ norm_layer=None,
118
+ drop_path_rate=0.0,
119
+ ):
120
+ super(Bottleneck, self).__init__()
121
+ if norm_layer is None:
122
+ norm_layer = nn.BatchNorm2d
123
+ width = int(planes * (base_width / 64.0)) * groups
124
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
125
+ self.conv1 = conv1x1(inplanes, width)
126
+ self.bn1 = norm_layer(width)
127
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
128
+ self.bn2 = norm_layer(width)
129
+ self.conv3 = conv1x1(width, planes * self.expansion)
130
+ self.bn3 = norm_layer(planes * self.expansion)
131
+ self.relu = nn.ReLU(inplace=True)
132
+ self.downsample = downsample
133
+ self.stride = stride
134
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
135
+
136
+ def forward(self, x):
137
+ identity = x
138
+
139
+ out = self.conv1(x)
140
+ out = self.bn1(out)
141
+ out = self.relu(out)
142
+
143
+ out = self.conv2(out)
144
+ out = self.bn2(out)
145
+ out = self.relu(out)
146
+
147
+ out = self.conv3(out)
148
+ out = self.bn3(out)
149
+
150
+ if self.downsample is not None:
151
+ identity = self.downsample(x)
152
+
153
+ out = identity + self.drop_path(out)
154
+ out = self.relu(out)
155
+
156
+ return out
157
+
158
+
159
+ class ResNet(nn.Module):
160
+ def __init__(
161
+ self,
162
+ layers,
163
+ zero_init_residual=False,
164
+ groups=1,
165
+ width_per_group=64,
166
+ replace_stride_with_dilation=None,
167
+ norm_layer=None,
168
+ drop_path_rate=0.0,
169
+ ):
170
+ super(ResNet, self).__init__()
171
+ if norm_layer is None:
172
+ norm_layer = nn.BatchNorm2d
173
+ self._norm_layer = norm_layer
174
+
175
+ self.inplanes = 64
176
+ self.dilation = 1
177
+ if replace_stride_with_dilation is None:
178
+ # each element in the tuple indicates if we should replace
179
+ # the 2x2 stride with a dilated convolution instead
180
+ replace_stride_with_dilation = [False, False, False]
181
+ if len(replace_stride_with_dilation) != 3:
182
+ raise ValueError(
183
+ "replace_stride_with_dilation should be None "
184
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
185
+ )
186
+ self.groups = groups
187
+ self.base_width = width_per_group
188
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
189
+ self.bn1 = norm_layer(self.inplanes)
190
+ self.relu = nn.ReLU(inplace=True)
191
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
192
+ self.layer1 = self._make_layer(Bottleneck, 64, layers[0], drop_path_rate=drop_path_rate)
193
+ self.layer2 = self._make_layer(
194
+ Bottleneck, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0], drop_path_rate=drop_path_rate
195
+ )
196
+ self.layer3 = self._make_layer(
197
+ Bottleneck, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1], drop_path_rate=drop_path_rate
198
+ )
199
+
200
+ for m in self.modules():
201
+ if isinstance(m, nn.Conv2d):
202
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
203
+ elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d, nn.GroupNorm)):
204
+ nn.init.constant_(m.weight, 1)
205
+ nn.init.constant_(m.bias, 0)
206
+
207
+ # Zero-initialize the last BN in each residual branch,
208
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
209
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
210
+ if zero_init_residual:
211
+ for m in self.modules():
212
+ if isinstance(m, Bottleneck):
213
+ nn.init.constant_(m.bn3.weight, 0)
214
+ elif isinstance(m, BasicBlock):
215
+ nn.init.constant_(m.bn2.weight, 0)
216
+
217
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False, drop_path_rate=0.0):
218
+ norm_layer = self._norm_layer
219
+ downsample = None
220
+ previous_dilation = self.dilation
221
+ if dilate:
222
+ self.dilation *= stride
223
+ stride = 1
224
+ if stride != 1 or self.inplanes != planes * block.expansion:
225
+ downsample = nn.Sequential(
226
+ conv1x1(self.inplanes, planes * block.expansion, stride),
227
+ norm_layer(planes * block.expansion),
228
+ )
229
+
230
+ layers = []
231
+ layers.append(
232
+ block(
233
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
234
+ )
235
+ )
236
+ self.inplanes = planes * block.expansion
237
+
238
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, blocks)]
239
+ for i in range(1, blocks):
240
+ layers.append(
241
+ block(
242
+ self.inplanes,
243
+ planes,
244
+ groups=self.groups,
245
+ base_width=self.base_width,
246
+ dilation=self.dilation,
247
+ norm_layer=norm_layer,
248
+ drop_path_rate=dpr[i],
249
+ )
250
+ )
251
+
252
+ return nn.Sequential(*layers)
253
+
254
+ def _forward_impl(self, x):
255
+ # See note [TorchScript super()]
256
+ x = self.conv1(x)
257
+ x = self.bn1(x)
258
+ x = self.relu(x)
259
+ x = self.maxpool(x)
260
+
261
+ x = self.layer1(x)
262
+ x = self.layer2(x)
263
+ x = self.layer3(x)
264
+
265
+ return x
266
+
267
+ def forward(self, x):
268
+ return self._forward_impl(x)
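This backbone is a truncated ResNet: stem plus layer1-3 only, with no layer4 and no classification head, so with Bottleneck blocks the output has 256 * 4 = 1024 channels at 1/16 of the input resolution. A quick shape check with illustrative layer counts (not necessarily the ones the checkpoint uses):

import torch
from resnet import ResNet   # the module added above

backbone = ResNet(layers=[3, 4, 6])        # hypothetical block counts for layer1-3
x = torch.randn(1, 3, 512, 512)
feat = backbone(x)
print(feat.shape)                          # torch.Size([1, 1024, 32, 32]): 1024 channels at 1/16 resolution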
special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": true,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": true,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "clean_up_tokenization_spaces": true,
12
+ "cls_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "eos_token": {
21
+ "__type": "AddedToken",
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": true,
25
+ "rstrip": false,
26
+ "single_word": false
27
+ },
28
+ "errors": "replace",
29
+ "mask_token": {
30
+ "__type": "AddedToken",
31
+ "content": "<mask>",
32
+ "lstrip": true,
33
+ "normalized": true,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "model_max_length": 1000000000000000019884624838656,
38
+ "pad_token": {
39
+ "__type": "AddedToken",
40
+ "content": "<pad>",
41
+ "lstrip": false,
42
+ "normalized": true,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ },
46
+ "sep_token": {
47
+ "__type": "AddedToken",
48
+ "content": "</s>",
49
+ "lstrip": false,
50
+ "normalized": true,
51
+ "rstrip": false,
52
+ "single_word": false
53
+ },
54
+ "special_tokens_map_file": null,
55
+ "tokenizer_class": "BartTokenizer",
56
+ "truncation_side": "left",
57
+ "unk_token": {
58
+ "__type": "AddedToken",
59
+ "content": "<unk>",
60
+ "lstrip": false,
61
+ "normalized": true,
62
+ "rstrip": false,
63
+ "single_word": false
64
+ }
65
+ }
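Worth noting: `truncation_side` is "left", so when an input exceeds the length limit the tokenizer drops tokens from the beginning, which keeps the most recent dialogue turns. A hedged sketch (repo id is a placeholder):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("jxu124/TiO", use_fast=False)   # assumed repo id
history = " ".join(f"turn {i}." for i in range(200))                # dummy over-long dialogue history
ids = tok(history, truncation=True, max_length=32).input_ids
print(tok.decode(ids))   # only the most recent turns survive, because truncation_side="left"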
vocab.json ADDED
The diff for this file is too large to render. See raw diff