Fizzarolli committed on
Commit d2a1495 · verified · 1 Parent(s): ec9a459

Create modernberg_model.py

Files changed (1)
  1. modernberg_model.py +1470 -0
modernberg_model.py ADDED
@@ -0,0 +1,1470 @@
1
+ # Copyright 2024 Answer.AI, LightOn, and contributors, and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from contextlib import nullcontext
18
+ from typing import Dict, Literal, Optional, Tuple, Union
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.configuration_utils import PretrainedConfig
28
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutput,
31
+ MaskedLMOutput,
32
+ SequenceClassifierOutput,
33
+ TokenClassifierOutput,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import (
37
+ add_code_sample_docstrings,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ is_flash_attn_2_available,
41
+ logging,
42
+ )
43
+ from transformers.utils.import_utils import is_triton_available, is_torchdynamo_compiling
44
+ from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
45
+ from transformers.models.modernbert.modular_modernbert import (_pad_modernbert_output, _unpad_modernbert_input, ModernBertEmbeddings, ModernBertMLP, ModernBertUnpaddedRotaryEmbedding)
46
+
47
+
48
+ if is_flash_attn_2_available():
49
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
50
+ from flash_attn.layers.rotary import RotaryEmbedding
51
+ from flash_attn.ops.triton.rotary import apply_rotary
52
+ else:
53
+ RotaryEmbedding = object
54
+
55
+ _CHECKPOINT_FOR_DOC = "answerdotai/ModernBERT-base"
56
+ _CONFIG_FOR_DOC = "ModernBertConfig"
57
+ _MAX_SQRT_GRADIENT = 1000.0
58
+
59
+ logger = logging.get_logger(__name__)
60
+
61
+
62
+ class ModernBergConfig(PretrainedConfig):
63
+ r"""
64
+ This is the configuration class to store the configuration of a [`ModernBergModel`]. It is used to instantiate a ModernBerg
65
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
66
+ defaults will yield a similar configuration to that of the ModernBERT-base.
67
+
68
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
69
+ documentation from [`PretrainedConfig`] for more information.
70
+
71
+ Args:
72
+ vocab_size (`int`, *optional*, defaults to 50368):
73
+ Vocabulary size of the ModernBerg model. Defines the number of different tokens that can be represented by the
74
+ `input_ids` passed when calling [`ModernBergModel`]
75
+ hidden_size (`int`, *optional*, defaults to 768):
76
+ Dimension of the hidden representations.
77
+ intermediate_size (`int`, *optional*, defaults to 1152):
78
+ Dimension of the MLP representations.
79
+ num_hidden_layers (`int`, *optional*, defaults to 22):
80
+ Number of hidden layers in the Transformer encoder.
81
+ num_attention_heads (`int`, *optional*, defaults to 12):
82
+ Number of attention heads for each attention layer in the Transformer encoder.
83
+ lru_width (`int`, *optional*, defaults to 1152):
84
+ The dimension of the RG-LRU recurrent state. Must be divisible by `num_attention_heads`.
+ conv1d_width (`int`, *optional*, defaults to 4):
+ The kernel size of the depthwise `Conv1d` used in the Griffin recurrent blocks.
85
+ hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
86
+ The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
87
+ if not specified.
88
+ max_position_embeddings (`int`, *optional*, defaults to 8192):
89
+ The maximum sequence length that this model might ever be used with.
90
+ initializer_range (`float`, *optional*, defaults to 0.02):
91
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
92
+ initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
93
+ The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
94
+ norm_eps (`float`, *optional*, defaults to 1e-05):
95
+ The epsilon used by the rms normalization layers.
96
+ norm_bias (`bool`, *optional*, defaults to `False`):
97
+ Whether to use bias in the normalization layers.
98
+ pad_token_id (`int`, *optional*, defaults to 50283):
99
+ Padding token id.
100
+ eos_token_id (`int`, *optional*, defaults to 50282):
101
+ End of stream token id.
102
+ bos_token_id (`int`, *optional*, defaults to 50281):
103
+ Beginning of stream token id.
104
+ cls_token_id (`int`, *optional*, defaults to 50281):
105
+ Classification token id.
106
+ sep_token_id (`int`, *optional*, defaults to 50282):
107
+ Separation token id.
108
+ global_rope_theta (`float`, *optional*, defaults to 160000.0):
109
+ The base period of the global RoPE embeddings.
110
+ attention_bias (`bool`, *optional*, defaults to `False`):
111
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
112
+ attention_dropout (`float`, *optional*, defaults to 0.0):
113
+ The dropout ratio for the attention probabilities.
114
+ global_temporal_every_n_layers (`int`, *optional*, defaults to 3):
115
+ The number of layers between global temporal mixing layers.
116
+ local_attention (`int`, *optional*, defaults to 128):
117
+ The window size for local attention.
118
+ local_rope_theta (`float`, *optional*, defaults to 10000.0):
119
+ The base period of the local RoPE embeddings.
120
+ embedding_dropout (`float`, *optional*, defaults to 0.0):
121
+ The dropout ratio for the embeddings.
122
+ mlp_bias (`bool`, *optional*, defaults to `False`):
123
+ Whether to use bias in the MLP layers.
124
+ mlp_dropout (`float`, *optional*, defaults to 0.0):
125
+ The dropout ratio for the MLP layers.
126
+ decoder_bias (`bool`, *optional*, defaults to `True`):
127
+ Whether to use bias in the decoder layers.
128
+ classifier_pooling (`str`, *optional*, defaults to `"cls"`):
129
+ The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
130
+ CLS token doesn't attend to all tokens on long sequences.
131
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
132
+ The dropout ratio for the classifier.
133
+ classifier_bias (`bool`, *optional*, defaults to `False`):
134
+ Whether to use bias in the classifier.
135
+ classifier_activation (`str`, *optional*, defaults to `"gelu"`):
136
+ The activation function for the classifier.
137
+ deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
138
+ Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
139
+ sparse_prediction (`bool`, *optional*, defaults to `False`):
140
+ Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
141
+ sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
142
+ The index to ignore for the sparse prediction.
143
+ reference_compile (`bool`, *optional*):
144
+ Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
145
+ the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
146
+ shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
147
+ be faster in some scenarios.
148
+ repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
149
+ When True, ModernBergForMaskedLM keeps track of the logits' gradient when repadding for output. This only
150
+ applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.
151
+
152
+ Examples:
153
+
154
+ ```python
155
+ >>> from modernberg_model import ModernBergModel, ModernBergConfig
156
+
157
+ >>> # Initializing a ModernBerg style configuration
158
+ >>> configuration = ModernBergConfig()
159
+
160
+ >>> # Initializing a model from the ModernBerg-base style configuration
161
+ >>> model = ModernBergModel(configuration)
162
+
163
+ >>> # Accessing the model configuration
164
+ >>> configuration = model.config
165
+ ```"""
166
+
167
+ model_type = "modernbert"
168
+ keys_to_ignore_at_inference = ["past_key_values"]
169
+
170
+ def __init__(
171
+ self,
172
+ vocab_size=50368,
173
+ hidden_size=768,
174
+ intermediate_size=1152,
175
+ num_hidden_layers=22,
176
+ num_attention_heads=12,
177
+ lru_width=1152,
178
+ conv1d_width=4,
179
+ hidden_activation="gelu",
180
+ max_position_embeddings=8192,
181
+ initializer_range=0.02,
182
+ initializer_cutoff_factor=2.0,
183
+ norm_eps=1e-5,
184
+ norm_bias=False,
185
+ pad_token_id=50283,
186
+ eos_token_id=50282,
187
+ bos_token_id=50281,
188
+ cls_token_id=50281,
189
+ sep_token_id=50282,
190
+ global_rope_theta=160000.0,
191
+ attention_bias=False,
192
+ attention_dropout=0.0,
193
+ global_temporal_every_n_layers=3,
194
+ local_attention=128,
195
+ local_rope_theta=10000.0,
196
+ embedding_dropout=0.0,
197
+ mlp_bias=False,
198
+ mlp_dropout=0.0,
199
+ decoder_bias=True,
200
+ classifier_pooling: Literal["cls", "mean"] = "cls",
201
+ classifier_dropout=0.0,
202
+ classifier_bias=False,
203
+ classifier_activation="gelu",
204
+ deterministic_flash_attn=False,
205
+ sparse_prediction=False,
206
+ sparse_pred_ignore_index=-100,
207
+ reference_compile=None,
208
+ repad_logits_with_grad=False,
209
+ **kwargs,
210
+ ):
211
+ super().__init__(
212
+ pad_token_id=pad_token_id,
213
+ bos_token_id=bos_token_id,
214
+ eos_token_id=eos_token_id,
215
+ cls_token_id=cls_token_id,
216
+ sep_token_id=sep_token_id,
217
+ **kwargs,
218
+ )
219
+ self.vocab_size = vocab_size
220
+ self.max_position_embeddings = max_position_embeddings
221
+ self.hidden_size = hidden_size
222
+ self.intermediate_size = intermediate_size
223
+ self.num_hidden_layers = num_hidden_layers
224
+ self.num_attention_heads = num_attention_heads
225
+ self.lru_width = lru_width
226
+ self.conv1d_width = conv1d_width
227
+ self.initializer_range = initializer_range
228
+ self.initializer_cutoff_factor = initializer_cutoff_factor
229
+ self.norm_eps = norm_eps
230
+ self.norm_bias = norm_bias
231
+ self.global_rope_theta = global_rope_theta
232
+ self.attention_bias = attention_bias
233
+ self.attention_dropout = attention_dropout
234
+ self.hidden_activation = hidden_activation
235
+ self.global_temporal_every_n_layers = global_temporal_every_n_layers
236
+ self.local_attention = local_attention
237
+ self.local_rope_theta = local_rope_theta
238
+ self.embedding_dropout = embedding_dropout
239
+ self.mlp_bias = mlp_bias
240
+ self.mlp_dropout = mlp_dropout
241
+ self.decoder_bias = decoder_bias
242
+ self.classifier_pooling = classifier_pooling
243
+ self.classifier_dropout = classifier_dropout
244
+ self.classifier_bias = classifier_bias
245
+ self.classifier_activation = classifier_activation
246
+ self.deterministic_flash_attn = deterministic_flash_attn
247
+ self.sparse_prediction = sparse_prediction
248
+ self.sparse_pred_ignore_index = sparse_pred_ignore_index
249
+ self.reference_compile = reference_compile
250
+ self.repad_logits_with_grad = repad_logits_with_grad
251
+
252
+ if self.classifier_pooling not in ["cls", "mean"]:
253
+ raise ValueError(
254
+ f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
255
+ )
256
+
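+ # Usage sketch (illustrative, added as documentation; not part of the original commit): both
+ # `hidden_size` and `lru_width` must be divisible by `num_attention_heads`, since the attention
+ # layers and the RG-LRU both split their width across heads.
+ # >>> cfg = ModernBergConfig(hidden_size=256, intermediate_size=512, num_hidden_layers=6,
+ # ...                        num_attention_heads=8, lru_width=256, conv1d_width=4)
+ # >>> cfg.hidden_size % cfg.num_attention_heads, cfg.lru_width % cfg.num_attention_heads
+ # (0, 0)
+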
257
+ class SqrtBoundDerivative(torch.autograd.Function):
258
+ """Computes a square root with a gradient clipped at `_MAX_SQRT_GRADIENT`."""
259
+
260
+ @staticmethod
261
+ def forward(ctx, x: torch.Tensor) -> torch.Tensor:
262
+ """The forward pass, which is a normal `sqrt`."""
263
+ ctx.save_for_backward(x)
264
+ return torch.sqrt(x)
265
+
266
+ @staticmethod
267
+ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
268
+ """The backward pass, which clips the `sqrt` gradient."""
269
+ (x,) = ctx.saved_tensors
270
+ clipped_x_times_4 = torch.clip(4.0 * x, min=1 / (_MAX_SQRT_GRADIENT**2))
271
+ return grad_output / torch.sqrt(clipped_x_times_4)
272
+
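+ # Gradient bound sketch (illustrative): d/dx sqrt(x) = 1 / sqrt(4x), so clamping 4x from below
+ # at 1 / _MAX_SQRT_GRADIENT**2 caps the gradient magnitude at _MAX_SQRT_GRADIENT = 1000.
+ # >>> x = torch.zeros(1, requires_grad=True)
+ # >>> SqrtBoundDerivative.apply(x).backward()
+ # >>> x.grad  # a plain torch.sqrt would give inf here
+ # tensor([1000.])
+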
273
+ class GriffinRglru(nn.Module):
274
+ """A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
275
+
276
+ def __init__(self, config: ModernBergConfig):
277
+ super().__init__()
278
+ self.num_attention_heads = config.num_attention_heads
279
+ self.block_width = config.lru_width // self.num_attention_heads
280
+
281
+ self.recurrent_param = nn.Parameter(torch.empty([config.lru_width]))
282
+ self.input_gate_weight = nn.Parameter(
283
+ torch.empty([self.num_attention_heads, self.block_width, self.block_width])
284
+ )
285
+ self.input_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
286
+
287
+ self.recurrent_gate_weight = nn.Parameter(
288
+ torch.empty([self.num_attention_heads, self.block_width, self.block_width])
289
+ )
290
+ self.recurrent_gate_bias = nn.Parameter(torch.empty([self.num_attention_heads, self.block_width]))
291
+ self.recurrent_states = None
292
+
293
+ def forward(
294
+ self,
295
+ activations: torch.Tensor,
296
+ position_ids: torch.Tensor,
297
+ ) -> torch.Tensor:
298
+ batch_size, seq_len, lru_width = activations.shape
299
+ reset = position_ids[:, :, None] == 0
300
+
301
+ reshape_act = activations.reshape(batch_size * seq_len, self.num_attention_heads, self.block_width)
302
+ reshape_act = reshape_act.permute(1, 0, 2)
303
+
304
+ res = torch.baddbmm(self.input_gate_bias[:, None, :], reshape_act, self.input_gate_weight)
305
+ input_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
306
+
307
+ res = torch.baddbmm(self.recurrent_gate_bias[:, None, :], reshape_act, self.recurrent_gate_weight)
308
+ recurrent_gate = torch.sigmoid(res.transpose(0, 1).reshape(batch_size, seq_len, lru_width))
309
+
310
+ # Compute the parameter `A` of the recurrence.
311
+ log_recurrent_gate = -8.0 * recurrent_gate * nn.functional.softplus(self.recurrent_param)
312
+ recurrent_gate = torch.exp(log_recurrent_gate)
313
+ a_square = torch.exp(2 * log_recurrent_gate)
314
+
315
+ # Gate the input.
316
+ gated_inputs = activations * input_gate
317
+
318
+ # Apply gamma normalization to the input. We need to clip the derivatives of
319
+ # `sqrt` in order to prevent NaNs during training in bfloat16. TODO a bit annoying
320
+ multiplier = 1
321
+ tracing = isinstance(activations, torch.fx.Proxy) or is_torchdynamo_compiling()
322
+ if not torch.jit.is_tracing() and not tracing:
323
+ multiplier = SqrtBoundDerivative.apply(1 - a_square)
324
+ multiplier = reset + ~reset * multiplier
325
+ normalized_x = gated_inputs * multiplier.type(activations.dtype)
326
+
327
+ hidden_states, recurrent_states = self._rnn_scan(
328
+ hidden_states=normalized_x,
329
+ recurrent_gate=recurrent_gate,
330
+ reset=reset,
331
+ recurrent_states=self.recurrent_states,
332
+ )
333
+ self.recurrent_states = recurrent_states
334
+ return hidden_states
335
+
336
+ # TODO refactor
337
+ def _rnn_scan(
338
+ self,
339
+ hidden_states: torch.Tensor,
340
+ recurrent_gate: torch.Tensor,
341
+ reset: torch.Tensor,
342
+ recurrent_states: Union[torch.Tensor, None],
343
+ acc_dtype: torch.dtype = torch.float32,
344
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
345
+ """Runs the recurrence of a linear RNN.
346
+
347
+ Args:
348
+ hidden_states: The input sequence.
349
+ recurrent_gate: The diagonal of the recurrence matrix `A`.
350
+ reset: Indicator of document boundaries, e.g. when to reset the hidden state
351
+ of the RNN.
352
+ recurrent_states: The initial hidden state.
353
+ acc_dtype: The data type for the accumulation.
354
+
355
+ Returns:
356
+ The output of the linear recurrence.
357
+ """
358
+ # Multiply `a` by the reset.
359
+ recurrent_gate = recurrent_gate * ~reset
360
+
361
+ if hidden_states.shape[1] == 1:
362
+ # Using scan in sampling mode.
363
+ if recurrent_states is None: # same here, when decoding you always have cache
364
+ return hidden_states, hidden_states[:, 0].type(acc_dtype)
365
+
366
+ else:
367
+ contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
368
+ recurrent_gate.device
369
+ )
370
+ contextualized_states += hidden_states.type(acc_dtype)
371
+ return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
372
+
373
+ else:
374
+ # Using scan in linear mode.
375
+ if recurrent_states is None:
376
+ recurrent_states = torch.zeros(hidden_states[:, 0].shape, dtype=acc_dtype, device=hidden_states.device)
377
+
378
+ contextualized_states = torch.zeros_like(hidden_states)
379
+ for t in range(hidden_states.shape[1]):
380
+ recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
381
+ recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
382
+ contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
383
+
384
+ return contextualized_states, recurrent_states
385
+
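+ # Recurrence sketch (illustrative summary of the forward pass above): per timestep t, with
+ # block-diagonal (per-head) gate projections W_i, W_r and learned parameter Lambda,
+ #   i_t = sigmoid(W_i x_t + b_i)                         # input gate
+ #   r_t = sigmoid(W_r x_t + b_r)                         # recurrence gate
+ #   a_t = exp(-8 * r_t * softplus(Lambda))               # diagonal recurrence, in (0, 1)
+ #   h_t = a_t * h_{t-1} + sqrt(1 - a_t**2) * (i_t * x_t)
+ # and the recurrence is reset (a_t = 0, no sqrt scaling) wherever position_ids == 0.
+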
386
+ class GriffinRecurrentblock(nn.Module):
387
+ """Griffin and Hawk's recurrent block."""
388
+
389
+ def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None):
390
+ super().__init__()
391
+ self.lru_width = config.lru_width
392
+ self.hidden_size = config.hidden_size
393
+ self.linear_y = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
394
+ self.linear_x = nn.Linear(in_features=config.hidden_size, out_features=config.lru_width)
395
+ self.linear_out = nn.Linear(in_features=config.lru_width, out_features=config.hidden_size)
396
+ self.conv1d_width = config.conv1d_width
397
+ self.conv_1d = nn.Conv1d(
398
+ config.lru_width,
399
+ config.lru_width,
400
+ kernel_size=config.conv1d_width,
401
+ groups=config.lru_width,
402
+ padding=config.conv1d_width - 1,
403
+ )
404
+ self.rg_lru = GriffinRglru(config)
405
+ self.act_fn = ACT2FN[config.hidden_activation]
406
+
407
+ self.conv1d_state = None
408
+
409
+ def forward(
410
+ self,
411
+ input_states: torch.Tensor,
412
+ position_ids: torch.Tensor,
413
+ attention_mask: Optional[torch.Tensor] = None,
414
+ cache_position: Optional[torch.Tensor] = None,
415
+ use_cache: bool = False,
+ **kwargs,  # absorbs attention-only kwargs (sliding_window_mask, cu_seqlens, max_seqlen, output_attentions) passed by ModernBergTemporalLayer
416
+ ) -> torch.Tensor:
417
+ _, seq_len, _ = input_states.shape
418
+
419
+ y_branch = self.linear_y(input_states)
420
+ y_branch = self.act_fn(y_branch)
421
+
422
+ x_branch = self.linear_x(input_states)
423
+ x_branch = x_branch.transpose(1, 2)
424
+
425
+ if use_cache:
426
+ if cache_position.shape[0] != 1: # prefill
427
+ self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
428
+ x_branch = self.conv_1d(x_branch)[..., :seq_len]
429
+ else: # decoding
430
+ conv_state = torch.cat((self.conv1d_state, x_branch), -1)
431
+ x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias
432
+ x_branch = x_branch.unsqueeze(-1)
433
+ self.conv1d_state = conv_state[:, :, 1:]
434
+ else:
435
+ x_branch = self.conv_1d(x_branch)[..., :seq_len]
436
+
437
+ x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
438
+
439
+ hidden_states = x_branch * y_branch
440
+ hidden_states = self.linear_out(hidden_states)
441
+ return hidden_states
442
+
443
+ def _setup_cache(self, batch, device, dtype):
444
+ # recurrent_states always computed in full precision
445
+ self.rg_lru.recurrent_states = torch.zeros((batch, self.lru_width), device=device, dtype=torch.float32)
446
+ self.conv1d_state = torch.zeros((batch, self.lru_width, self.conv1d_width - 1), device=device, dtype=dtype)  # conv_1d operates on lru_width channels
447
+
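+ # Dataflow sketch (illustrative summary of GriffinRecurrentblock.forward without caching):
+ #   y = act_fn(linear_y(x))               # gating branch
+ #   u = conv_1d(linear_x(x))              # depthwise causal conv over time, kernel size conv1d_width
+ #   u = rg_lru(u, position_ids)           # real-gated linear recurrence (see GriffinRglru)
+ #   out = linear_out(u * y)               # gate and project back to hidden_size
+ # The recurrent path scales linearly with sequence length, unlike the attention layers.
+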
448
+ def eager_attention_forward(
449
+ module: "ModernBergAttention",
450
+ qkv: torch.Tensor,
451
+ attention_mask: torch.Tensor,
452
+ sliding_window_mask: torch.Tensor,
453
+ position_ids: Optional[torch.LongTensor],
454
+ local_attention: Tuple[int, int],
455
+ bs: int,
456
+ dim: int,
457
+ output_attentions: Optional[bool] = False,
458
+ **_kwargs,
459
+ ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
460
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
461
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
462
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
463
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
464
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
465
+
466
+ scale = module.head_dim**-0.5
467
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scale
468
+
469
+ if local_attention != (-1, -1):
470
+ attention_mask = sliding_window_mask
471
+
472
+ attn_weights = attn_weights + attention_mask
473
+
474
+ # upcast attention to fp32
475
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
476
+ attn_weights = nn.functional.dropout(attn_weights, p=module.attention_dropout, training=module.training)
477
+ attn_output = torch.matmul(attn_weights, value)
478
+ attn_output = attn_output.transpose(1, 2).contiguous()
479
+ attn_output = attn_output.view(bs, -1, dim)
480
+ if output_attentions:
481
+ return (attn_output, attn_weights)
482
+ return (attn_output,)
483
+
484
+ def flash_attention_forward(
485
+ module: "ModernBergAttention",
486
+ qkv: torch.Tensor,
487
+ rotary_emb: ModernBertUnpaddedRotaryEmbedding,
488
+ cu_seqlens: torch.Tensor,
489
+ max_seqlen: int,
490
+ local_attention: Tuple[int, int],
491
+ bs: int,
492
+ dim: int,
493
+ target_dtype: torch.dtype = torch.bfloat16,
494
+ **_kwargs,
495
+ ) -> Tuple[torch.Tensor]:
496
+ # (total_seqlen, 3, nheads, headdim)
497
+ qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
498
+
499
+ convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
500
+ if convert_dtype:
501
+ # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
502
+ # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
503
+ orig_dtype = qkv.dtype
504
+ qkv = qkv.to(target_dtype)
505
+
506
+ attn = flash_attn_varlen_qkvpacked_func(
507
+ qkv,
508
+ cu_seqlens=cu_seqlens,
509
+ max_seqlen=max_seqlen,
510
+ dropout_p=module.attention_dropout if module.training else 0.0,
511
+ deterministic=module.deterministic_flash_attn,
512
+ window_size=local_attention,
513
+ )
514
+ attn = attn.to(orig_dtype) # type: ignore
515
+ else:
516
+ attn = flash_attn_varlen_qkvpacked_func(
517
+ qkv,
518
+ cu_seqlens=cu_seqlens,
519
+ max_seqlen=max_seqlen,
520
+ dropout_p=module.attention_dropout if module.training else 0.0,
521
+ deterministic=module.deterministic_flash_attn,
522
+ window_size=local_attention,
523
+ )
524
+ return (attn.view(bs, dim),)
525
+
526
+ def sdpa_attention_forward(
527
+ module: "ModernBergAttention",
528
+ qkv: torch.Tensor,
529
+ attention_mask: torch.Tensor,
530
+ sliding_window_mask: torch.Tensor,
531
+ position_ids: Optional[torch.LongTensor],
532
+ local_attention: Tuple[int, int],
533
+ bs: int,
534
+ dim: int,
535
+ **_kwargs,
536
+ ) -> Tuple[torch.Tensor]:
537
+ # qkv: [batch_size, seqlen, 3, nheads, headdim]
538
+ cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
539
+ query, key, value = qkv.transpose(3, 1).unbind(dim=2)
540
+ # query, key, value: [batch_size, heads, seq_len, head_dim]
541
+ query, key = apply_rotary_pos_emb(query, key, cos, sin)
542
+
543
+ if local_attention != (-1, -1):
544
+ attention_mask = sliding_window_mask
545
+
546
+ attn_output = (
547
+ F.scaled_dot_product_attention(
548
+ query,
549
+ key,
550
+ value,
551
+ dropout_p=module.attention_dropout if module.training else 0.0,
552
+ attn_mask=attention_mask,
553
+ )
554
+ .transpose(1, 2)
555
+ .contiguous()
556
+ )
557
+ attn_output = attn_output.view(bs, -1, dim)
558
+ return (attn_output,)
559
+
560
+ MODERNBERT_ATTENTION_FUNCTION = {
561
+ "flash_attention_2": flash_attention_forward,
562
+ "eager": eager_attention_forward,
563
+ "sdpa": sdpa_attention_forward,
564
+ }
565
+
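+ # Dispatch note (illustrative): ModernBergAttention.forward selects the kernel via
+ # MODERNBERT_ATTENTION_FUNCTION[config._attn_implementation]. The "flash_attention_2" entry
+ # expects unpadded qkv of shape (total_tokens, 3, nheads, headdim) plus cu_seqlens/max_seqlen,
+ # while "eager" and "sdpa" expect padded qkv of shape (batch, seq_len, 3, nheads, headdim)
+ # together with 4D attention / sliding-window masks.
+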
566
+ class ModernBergRotaryEmbedding(GemmaRotaryEmbedding):
567
+ def __init__(self, config: ModernBergConfig, dim: int, base: float, device: Optional[torch.device] = None):
568
+ # JANK: GemmaRotaryEmbedding reads the base frequency from `config.rope_theta`, so set it here before calling super().__init__().
569
+ config.rope_theta = base
570
+ super().__init__(config=config, device=device)
571
+ inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
572
+
573
+ class ModernBergAttention(nn.Module):
574
+ """Performs multi-headed self attention on a batch of unpadded sequences.
575
+
576
+ If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
577
+ If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
578
+ which requires padding and unpadding inputs, adding some overhead.
579
+
580
+ See `forward` method for additional details.
581
+ """
582
+
583
+ def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None):
584
+ super().__init__()
585
+ self.config = config
586
+ self.layer_id = layer_id
587
+
588
+ if config.hidden_size % config.num_attention_heads != 0:
589
+ raise ValueError(
590
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
591
+ )
592
+
593
+ self.attention_dropout = config.attention_dropout
594
+ self.deterministic_flash_attn = config.deterministic_flash_attn
595
+ self.num_heads = config.num_attention_heads
596
+ self.head_dim = config.hidden_size // config.num_attention_heads
597
+ self.all_head_size = self.head_dim * self.num_heads
598
+ self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias)
599
+
600
+ assert layer_id % config.global_temporal_every_n_layers != 0, "ModernBerg does not support global self-attention"
601
+ self.local_attention = (config.local_attention // 2, config.local_attention // 2)
602
+
603
+ rope_theta = config.global_rope_theta
604
+ max_position_embeddings = config.max_position_embeddings
605
+ if self.local_attention != (-1, -1):
606
+ if config.local_rope_theta is not None:
607
+ rope_theta = config.local_rope_theta
608
+ max_position_embeddings = config.local_attention
609
+
610
+ if config._attn_implementation == "flash_attention_2":
611
+ self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(
612
+ dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
613
+ )
614
+ else:
615
+ self.rotary_emb = ModernBergRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)
616
+
617
+ self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
618
+ self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
619
+ self.pruned_heads = set()
620
+
621
+ def forward(
622
+ self,
623
+ hidden_states: torch.Tensor,
624
+ output_attentions: Optional[bool] = False,
625
+ **kwargs,
626
+ ) -> torch.Tensor:
627
+ qkv = self.Wqkv(hidden_states)
628
+
629
+ bs = hidden_states.shape[0]
630
+ if self.config._attn_implementation == "flash_attention_2":
631
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
632
+ else:
633
+ qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim)
634
+
635
+ attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
636
+ self,
637
+ qkv=qkv,
638
+ rotary_emb=self.rotary_emb,
639
+ local_attention=self.local_attention,
640
+ bs=bs,
641
+ dim=self.all_head_size,
642
+ output_attentions=output_attentions,
643
+ **kwargs,
644
+ )
645
+ hidden_states = attn_outputs[0]
646
+ hidden_states = self.out_drop(self.Wo(hidden_states))
647
+
648
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
649
+
650
+ class ModernBergTemporalLayer(nn.Module):
651
+ def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None):
652
+ super().__init__()
653
+ self.config = config
654
+ if layer_id % config.global_temporal_every_n_layers == 0:
655
+ self.temporal = GriffinRecurrentblock(config=config, layer_id=layer_id)
656
+ else:
657
+ self.temporal = ModernBergAttention(config=config, layer_id=layer_id)
658
+
659
+ def forward(self, hidden_states: torch.Tensor, **kwargs):
660
+ outputs = self.temporal(hidden_states, **kwargs)
+ # Attention returns a tuple, the recurrent block a bare tensor; normalize for the encoder layer.
+ return outputs if isinstance(outputs, tuple) else (outputs,)
661
+
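+ # Layer schedule note (illustrative): with the default global_temporal_every_n_layers=3,
+ # layer_ids 0, 3, 6, ... get a GriffinRecurrentblock (in place of ModernBERT's periodic
+ # global-attention layers), and every other layer uses the local sliding-window
+ # ModernBergAttention; the assert in ModernBergAttention.__init__ enforces that no attention
+ # layer lands on a "global" index.
+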
662
+ class ModernBergEncoderLayer(nn.Module):
663
+ def __init__(self, config: ModernBergConfig, layer_id: Optional[int] = None):
664
+ super().__init__()
665
+ self.config = config
666
+ if layer_id == 0:
667
+ self.temporal_norm = nn.Identity()
668
+ else:
669
+ self.temporal_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
670
+ self.temporal = ModernBergTemporalLayer(config=config, layer_id=layer_id)
671
+ self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
672
+ self.mlp = ModernBertMLP(config)
673
+
674
+ @torch.compile(dynamic=True)
675
+ def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor:
676
+ return self.mlp(self.mlp_norm(hidden_states))
677
+
678
+ def forward(
679
+ self,
680
+ hidden_states: torch.Tensor,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ sliding_window_mask: Optional[torch.Tensor] = None,
683
+ position_ids: Optional[torch.LongTensor] = None,
684
+ cu_seqlens: Optional[torch.Tensor] = None,
685
+ max_seqlen: Optional[int] = None,
686
+ output_attentions: Optional[bool] = False,
687
+ ) -> torch.Tensor:
688
+ attn_outputs = self.temporal(
689
+ self.temporal_norm(hidden_states),
690
+ attention_mask=attention_mask,
691
+ sliding_window_mask=sliding_window_mask,
692
+ position_ids=position_ids,
693
+ cu_seqlens=cu_seqlens,
694
+ max_seqlen=max_seqlen,
695
+ output_attentions=output_attentions,
696
+ )
697
+ hidden_states = hidden_states + attn_outputs[0]
698
+ mlp_output = (
699
+ self.compiled_mlp(hidden_states)
700
+ if self.config.reference_compile
701
+ else self.mlp(self.mlp_norm(hidden_states))
702
+ )
703
+ hidden_states = hidden_states + mlp_output
704
+
705
+ return (hidden_states,) + attn_outputs[1:] # add attentions if outputted
706
+
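+ # Residual structure note (illustrative): each encoder layer is pre-norm with two residual
+ # branches, hidden = hidden + temporal(norm(hidden)) followed by hidden = hidden + mlp(norm(hidden)),
+ # identical to ModernBERT except that the temporal sublayer may be a recurrence instead of
+ # attention (and layer 0 skips the first norm, matching ModernBERT).
+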
707
+ MODERNBERG_START_DOCSTRING = r"""
708
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
709
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
710
+ etc.)
711
+
712
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
713
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
714
+ and behavior.
715
+
716
+ Parameters:
717
+ config ([`ModernBergConfig`]):
718
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
719
+ load the weights associated with the model, only the configuration. Check out the
720
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
721
+ """
722
+
723
+
724
+ @add_start_docstrings(
725
+ "The bare ModernBerg Model outputting raw hidden-states without any specific head on top.",
726
+ MODERNBERG_START_DOCSTRING,
727
+ )
728
+ class ModernBergPreTrainedModel(PreTrainedModel):
729
+ config_class = ModernBergConfig
730
+ base_model_prefix = "model"
731
+ supports_gradient_checkpointing = True
732
+ _no_split_modules = ["ModernBergEmbeddings", "ModernBergEncoderLayer"]
733
+ _supports_flash_attn_2 = True
734
+ _supports_sdpa = True
735
+ _supports_flex_attn = False
736
+
737
+ def _init_weights(self, module: nn.Module):
738
+ cutoff_factor = self.config.initializer_cutoff_factor
739
+ if cutoff_factor is None:
740
+ cutoff_factor = 3
741
+
742
+ def init_weight(module: nn.Module, std: float):
743
+ nn.init.trunc_normal_(
744
+ module.weight,
745
+ mean=0.0,
746
+ std=std,
747
+ a=-cutoff_factor * std,
748
+ b=cutoff_factor * std,
749
+ )
750
+
751
+ if isinstance(module, nn.Linear):
752
+ if module.bias is not None:
753
+ nn.init.zeros_(module.bias)
754
+
755
+ stds = {
756
+ "in": self.config.initializer_range,
757
+ "out": self.config.initializer_range / math.sqrt(2.0 * self.config.num_hidden_layers),
758
+ "embedding": self.config.initializer_range,
759
+ "final_out": self.config.hidden_size**-0.5,
760
+ }
761
+
762
+ std = math.sqrt(self.config.initializer_range / self.config.conv1d_width)
763
+ if isinstance(module, ModernBertEmbeddings):
764
+ init_weight(module.tok_embeddings, stds["embedding"])
765
+ elif isinstance(module, ModernBertMLP):
766
+ init_weight(module.Wi, stds["in"])
767
+ init_weight(module.Wo, stds["out"])
768
+ elif isinstance(module, ModernBergAttention):
769
+ init_weight(module.Wqkv, stds["in"])
770
+ init_weight(module.Wo, stds["out"])
771
+ elif isinstance(module, GriffinRecurrentblock):
772
+ torch.nn.init.zeros_(module.linear_x.bias)
773
+ torch.nn.init.normal_(module.linear_x.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
774
+
775
+ torch.nn.init.zeros_(module.linear_y.bias)
776
+ torch.nn.init.normal_(module.linear_y.weight, mean=0.0, std=math.sqrt(1.0 / self.config.hidden_size))
777
+
778
+ std = math.sqrt(self.config.initializer_range / self.config.lru_width)
779
+ torch.nn.init.normal_(module.linear_out.weight, mean=0.0, std=std)
780
+ torch.nn.init.zeros_(module.linear_out.bias)
781
+ elif isinstance(module, GriffinRglru):
782
+ std = math.sqrt(
783
+ self.config.initializer_range / (self.config.lru_width // self.config.num_attention_heads)
784
+ )
785
+ torch.nn.init.normal_(module.input_gate_weight, mean=0.0, std=std)
786
+ torch.nn.init.normal_(module.recurrent_gate_weight, mean=0.0, std=std)
787
+ torch.nn.init.zeros_(module.input_gate_bias)
788
+ torch.nn.init.zeros_(module.recurrent_gate_bias)
789
+
790
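+ # (Note) The three in-place ops below initialize `recurrent_param` (Lambda) as
+ # log(1 / sqrt(u) - 1) for u ~ U(0.9**2, 0.999**2), i.e. the inverse softplus of -0.5 * log(u),
+ # matching the RG-LRU initialization used by RecurrentGemma / Griffin.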
+ module.recurrent_param.data.uniform_(0.9**2 + 1e-8, 0.999**2 + 1e-8)
791
+ module.recurrent_param.data.log_().mul_(0.5)
792
+ module.recurrent_param.data.neg_().exp_().sub_(1.0).log_()
793
+ elif isinstance(module, ModernBergPredictionHead):
794
+ init_weight(module.dense, stds["out"])
795
+ elif isinstance(module, ModernBergForMaskedLM):
796
+ init_weight(module.decoder, stds["out"])
797
+ elif isinstance(module, (ModernBergForSequenceClassification, ModernBergForTokenClassification)):
798
+ init_weight(module.classifier, stds["final_out"])
799
+ elif isinstance(module, nn.Linear):
800
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
801
+ if getattr(module, "bias", None) is not None:
802
+ torch.nn.init.zeros_(module.bias)
803
+
804
+ @classmethod
805
+ def _autoset_attn_implementation(
806
+ cls,
807
+ config,
808
+ use_flash_attention_2: bool = False,
809
+ torch_dtype: Optional[torch.dtype] = None,
810
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
811
+ check_device_map: bool = True,
812
+ ):
813
+ # If the user didn't specify anything, try to use flash_attention_2 if available.
814
+ # Otherwise we fall back to the default SDPA -> Eager from the super() method.
815
+ # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
816
+ # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
817
+ if config._attn_implementation_internal is None:
818
+ config._attn_implementation_internal = "flash_attention_2"
819
+ try:
820
+ return cls._check_and_enable_flash_attn_2(
821
+ config,
822
+ torch_dtype=torch.float16,
823
+ device_map=device_map,
824
+ hard_check_only=False,
825
+ check_device_map=check_device_map,
826
+ )
827
+ except (ValueError, ImportError):
828
+ config._attn_implementation_internal = None
829
+ return super()._autoset_attn_implementation(
830
+ config,
831
+ use_flash_attention_2=use_flash_attention_2,
832
+ torch_dtype=torch.float16,
833
+ device_map=device_map,
834
+ check_device_map=check_device_map,
835
+ )
836
+
837
+ def _maybe_set_compile(self):
838
+ if self.config.reference_compile is False:
839
+ return
840
+
841
+ if hasattr(self, "hf_device_map") and len(self.hf_device_map) > 1:
842
+ if self.config.reference_compile:
843
+ logger.warning_once(
844
+ "If `accelerate` split the model across devices, `torch.compile` will not work. "
845
+ "Falling back to non-compiled mode."
846
+ )
847
+ self.config.reference_compile = False
848
+
849
+ if self.device.type == "mps":
850
+ if self.config.reference_compile:
851
+ logger.warning_once(
852
+ "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. "
853
+ "Falling back to non-compiled mode."
854
+ )
855
+ self.config.reference_compile = False
856
+
857
+ if self.device.type == "cpu":
858
+ if self.config.reference_compile:
859
+ logger.warning_once(
860
+ "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
861
+ "Falling back to non-compiled mode."
862
+ )
863
+ self.config.reference_compile = False
864
+
865
+ if self.config.reference_compile is None:
866
+ self.config.reference_compile = is_triton_available()
867
+
868
+ def resize_token_embeddings(self, *args, **kwargs):
869
+ model_embeds = super().resize_token_embeddings(*args, **kwargs)
870
+
871
+ if self.config.reference_compile in {True, None}:
872
+ if self.config.reference_compile:
873
+ logger.warning_once(
874
+ "Resizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode."
875
+ )
876
+ self.config.reference_compile = False
877
+
878
+ return model_embeds
879
+
880
+
881
+ MODERNBERG_INPUTS_DOCSTRING = r"""
882
+ Args:
883
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
884
+ Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
885
+ by default should you provide it.
886
+
887
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
888
+ [`PreTrainedTokenizer.__call__`] for details.
889
+
890
+ [What are input IDs?](../glossary#input-ids)
891
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
892
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
893
+
894
+ - 1 for tokens that are **not masked**,
895
+ - 0 for tokens that are **masked**.
896
+
897
+ [What are attention masks?](../glossary#attention-mask)
898
+
899
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
900
+ [`PreTrainedTokenizer.__call__`] for details.
901
+
902
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
903
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
904
+ information on the default strategy.
905
+
906
+ - 1 indicates the head is **not masked**,
907
+ - 0 indicates the head is **masked**.
908
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
909
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
910
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
911
+ far-away tokens in the local attention layers when not using Flash Attention.
912
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
913
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
914
+ config.n_positions - 1]`.
915
+
916
+ [What are position IDs?](../glossary#position-ids)
917
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
918
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
919
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
920
+ model's internal embedding lookup matrix.
921
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
922
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
923
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
924
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
925
+ max_seqlen (`int`, *optional*):
926
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
927
+ batch_size (`int`, *optional*):
928
+ Batch size of the input sequences. Used to pad the output tensors.
929
+ seq_len (`int`, *optional*):
930
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
931
+ output_attentions (`bool`, *optional*):
932
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
933
+ tensors for more detail.
934
+ output_hidden_states (`bool`, *optional*):
935
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
936
+ more detail.
937
+ return_dict (`bool`, *optional*):
938
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
939
+ """
940
+
941
+
942
+ @add_start_docstrings(
943
+ "The bare ModernBerg Model outputting raw hidden-states without any specific head on top.",
944
+ MODERNBERG_START_DOCSTRING,
945
+ )
946
+ class ModernBergModel(ModernBergPreTrainedModel):
947
+ def __init__(self, config: ModernBergConfig):
948
+ super().__init__(config)
949
+ self.config = config
950
+ self.embeddings = ModernBertEmbeddings(config)
951
+ self.layers = nn.ModuleList(
952
+ [ModernBergEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
953
+ )
954
+ self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
955
+ self.gradient_checkpointing = False
956
+ self.post_init()
957
+
958
+ def get_input_embeddings(self):
959
+ return self.embeddings.tok_embeddings
960
+
961
+ def set_input_embeddings(self, value):
962
+ self.embeddings.tok_embeddings = value
963
+
964
+ @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
965
+ @add_code_sample_docstrings(
966
+ checkpoint=_CHECKPOINT_FOR_DOC,
967
+ output_type=BaseModelOutput,
968
+ config_class=_CONFIG_FOR_DOC,
969
+ )
970
+ def forward(
971
+ self,
972
+ input_ids: Optional[torch.LongTensor] = None,
973
+ attention_mask: Optional[torch.Tensor] = None,
974
+ sliding_window_mask: Optional[torch.Tensor] = None,
975
+ position_ids: Optional[torch.LongTensor] = None,
976
+ inputs_embeds: Optional[torch.Tensor] = None,
977
+ indices: Optional[torch.Tensor] = None,
978
+ cu_seqlens: Optional[torch.Tensor] = None,
979
+ max_seqlen: Optional[int] = None,
980
+ batch_size: Optional[int] = None,
981
+ seq_len: Optional[int] = None,
982
+ output_attentions: Optional[bool] = None,
983
+ output_hidden_states: Optional[bool] = None,
984
+ return_dict: Optional[bool] = None,
985
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutput]:
986
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
987
+ output_hidden_states = (
988
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
989
+ )
990
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
991
+
992
+ if (input_ids is None) ^ (inputs_embeds is not None):
993
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
994
+
995
+ all_hidden_states = () if output_hidden_states else None
996
+ all_self_attentions = () if output_attentions else None
997
+
998
+ self._maybe_set_compile()
999
+
1000
+ if input_ids is not None:
1001
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1002
+
1003
+ if batch_size is None and seq_len is None:
1004
+ if inputs_embeds is not None:
1005
+ batch_size, seq_len = inputs_embeds.shape[:2]
1006
+ else:
1007
+ batch_size, seq_len = input_ids.shape[:2]
1008
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1009
+
1010
+ if attention_mask is None:
1011
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1012
+
1013
+ repad = False
1014
+ if self.config._attn_implementation == "flash_attention_2":
1015
+ if indices is None and cu_seqlens is None and max_seqlen is None:
1016
+ repad = True
1017
+ if inputs_embeds is None:
1018
+ with torch.no_grad():
1019
+ input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1020
+ inputs=input_ids, attention_mask=attention_mask
1021
+ )
1022
+ else:
1023
+ inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
1024
+ inputs=inputs_embeds, attention_mask=attention_mask
1025
+ )
1026
+ else:
1027
+ if position_ids is None:
1028
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
1029
+
1030
+ attention_mask, sliding_window_mask = self._update_attention_mask(
1031
+ attention_mask, output_attentions=output_attentions
1032
+ )
1033
+
1034
+ hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
1035
+
1036
+ for encoder_layer in self.layers:
1037
+ if output_hidden_states:
1038
+ all_hidden_states = all_hidden_states + (hidden_states,)
1039
+
1040
+ if self.gradient_checkpointing and self.training:
1041
+ layer_outputs = self._gradient_checkpointing_func(
1042
+ encoder_layer.__call__,
1043
+ hidden_states,
1044
+ attention_mask,
1045
+ sliding_window_mask,
1046
+ position_ids,
1047
+ cu_seqlens,
1048
+ max_seqlen,
1049
+ output_attentions,
1050
+ )
1051
+ else:
1052
+ layer_outputs = encoder_layer(
1053
+ hidden_states,
1054
+ attention_mask=attention_mask,
1055
+ sliding_window_mask=sliding_window_mask,
1056
+ position_ids=position_ids,
1057
+ cu_seqlens=cu_seqlens,
1058
+ max_seqlen=max_seqlen,
1059
+ output_attentions=output_attentions,
1060
+ )
1061
+ hidden_states = layer_outputs[0]
1062
+ if output_attentions and len(layer_outputs) > 1:
1063
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1064
+
1065
+ if output_hidden_states:
1066
+ all_hidden_states = all_hidden_states + (hidden_states,)
1067
+
1068
+ hidden_states = self.final_norm(hidden_states)
1069
+
1070
+ if repad:
1071
+ hidden_states = _pad_modernbert_output(
1072
+ inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
1073
+ )
1074
+ if all_hidden_states is not None:
1075
+ all_hidden_states = tuple(
1076
+ _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
1077
+ for hs in all_hidden_states
1078
+ )
1079
+
1080
+ if not return_dict:
1081
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1082
+ return BaseModelOutput(
1083
+ last_hidden_state=hidden_states,
1084
+ hidden_states=all_hidden_states,
1085
+ attentions=all_self_attentions,
1086
+ )
1087
+
1088
+ def _update_attention_mask(self, attention_mask: torch.Tensor, output_attentions: bool) -> Tuple[torch.Tensor, torch.Tensor]:
1089
+ if output_attentions:
1090
+ if self.config._attn_implementation == "sdpa":
1091
+ logger.warning_once(
1092
+ "Outputting attentions is only supported with the 'eager' attention implementation, "
1093
+ 'not with "sdpa". Falling back to `attn_implementation="eager"`.'
1094
+ )
1095
+ self.config._attn_implementation = "eager"
1096
+ elif self.config._attn_implementation != "eager":
1097
+ logger.warning_once(
1098
+ "Outputting attentions is only supported with the eager attention implementation, "
1099
+ f'not with {self.config._attn_implementation}. Consider setting `attn_implementation="eager"`.'
1100
+ " Setting `output_attentions=False`."
1101
+ )
1102
+
1103
+ global_attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
1104
+
1105
+ # Create position indices
1106
+ rows = torch.arange(global_attention_mask.shape[2]).unsqueeze(0)
1107
+ # Calculate distance between positions
1108
+ distance = torch.abs(rows - rows.T)
1109
+
1110
+ # Create sliding window mask (1 for positions within window, 0 outside)
1111
+ window_mask = (
1112
+ (distance <= self.config.local_attention // 2).unsqueeze(0).unsqueeze(0).to(attention_mask.device)
1113
+ )
1114
+ # Combine with existing mask
1115
+ sliding_window_mask = global_attention_mask.masked_fill(window_mask.logical_not(), torch.finfo(self.dtype).min)
1116
+
1117
+ return global_attention_mask, sliding_window_mask
1118
+
1119
+
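+ # Sliding-window mask sketch (illustrative): each position may attend only to tokens within
+ # config.local_attention // 2 positions on either side. For seq_len = 5 and a +/-1 window,
+ # the boolean window (before it is combined with the padding mask) is:
+ # >>> rows = torch.arange(5).unsqueeze(0)
+ # >>> (torch.abs(rows - rows.T) <= 1).int()
+ # tensor([[1, 1, 0, 0, 0],
+ #         [1, 1, 1, 0, 0],
+ #         [0, 1, 1, 1, 0],
+ #         [0, 0, 1, 1, 1],
+ #         [0, 0, 0, 1, 1]])
+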
1120
+ class ModernBergPredictionHead(nn.Module):
1121
+ def __init__(self, config: ModernBergConfig):
1122
+ super().__init__()
1123
+ self.config = config
1124
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
1125
+ self.act = ACT2FN[config.classifier_activation]
1126
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
1127
+
1128
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1129
+ return self.norm(self.act(self.dense(hidden_states)))
1130
+
1131
+
1132
+ @add_start_docstrings(
1133
+ "The ModernBerg Model with a decoder head on top that is used for masked language modeling.",
1134
+ MODERNBERG_START_DOCSTRING,
1135
+ )
1136
+ class ModernBergForMaskedLM(ModernBergPreTrainedModel):
1137
+ _tied_weights_keys = ["decoder.weight"]
1138
+
1139
+ def __init__(self, config: ModernBergConfig):
1140
+ super().__init__(config)
1141
+ self.config = config
1142
+ self.model = ModernBergModel(config)
1143
+ self.head = ModernBergPredictionHead(config)
1144
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
1145
+
1146
+ self.sparse_prediction = self.config.sparse_prediction
1147
+ self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
1148
+
1149
+ # Initialize weights and apply final processing
1150
+ self.post_init()
1151
+
1152
+ def get_output_embeddings(self):
1153
+ return self.decoder
1154
+
1155
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1156
+ self.decoder = new_embeddings
1157
+
1158
+ @torch.compile(dynamic=True)
1159
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1160
+ return self.decoder(self.head(output))
1161
+
1162
+ @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
1163
+ @add_code_sample_docstrings(
1164
+ checkpoint=_CHECKPOINT_FOR_DOC,
1165
+ output_type=MaskedLMOutput,
1166
+ config_class=_CONFIG_FOR_DOC,
1167
+ )
1168
+ def forward(
1169
+ self,
1170
+ input_ids: Optional[torch.LongTensor] = None,
1171
+ attention_mask: Optional[torch.Tensor] = None,
1172
+ sliding_window_mask: Optional[torch.Tensor] = None,
1173
+ position_ids: Optional[torch.Tensor] = None,
1174
+ inputs_embeds: Optional[torch.Tensor] = None,
1175
+ labels: Optional[torch.Tensor] = None,
1176
+ indices: Optional[torch.Tensor] = None,
1177
+ cu_seqlens: Optional[torch.Tensor] = None,
1178
+ max_seqlen: Optional[int] = None,
1179
+ batch_size: Optional[int] = None,
1180
+ seq_len: Optional[int] = None,
1181
+ output_attentions: Optional[bool] = None,
1182
+ output_hidden_states: Optional[bool] = None,
1183
+ return_dict: Optional[bool] = None,
1184
+ **kwargs,
1185
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1186
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1187
+ self._maybe_set_compile()
1188
+
1189
+ if self.config._attn_implementation == "flash_attention_2":
1190
+ if indices is None and cu_seqlens is None and max_seqlen is None:
1191
+ if batch_size is None and seq_len is None:
1192
+ if inputs_embeds is not None:
1193
+ batch_size, seq_len = inputs_embeds.shape[:2]
1194
+ else:
1195
+ batch_size, seq_len = input_ids.shape[:2]
1196
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1197
+
1198
+ if attention_mask is None:
1199
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1200
+
1201
+ if inputs_embeds is None:
1202
+ with torch.no_grad():
1203
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1204
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1205
+ )
1206
+ else:
1207
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
1208
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1209
+ )
1210
+
1211
+ outputs = self.model(
1212
+ input_ids=input_ids,
1213
+ attention_mask=attention_mask,
1214
+ sliding_window_mask=sliding_window_mask,
1215
+ position_ids=position_ids,
1216
+ inputs_embeds=inputs_embeds,
1217
+ indices=indices,
1218
+ cu_seqlens=cu_seqlens,
1219
+ max_seqlen=max_seqlen,
1220
+ batch_size=batch_size,
1221
+ seq_len=seq_len,
1222
+ output_attentions=output_attentions,
1223
+ output_hidden_states=output_hidden_states,
1224
+ return_dict=return_dict,
1225
+ )
1226
+ last_hidden_state = outputs[0]
1227
+
1228
+ if self.sparse_prediction and labels is not None:
1229
+ # flatten labels and output first
1230
+ labels = labels.view(-1)
1231
+ last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
1232
+
1233
+ # then filter out the non-masked tokens
1234
+ mask_tokens = labels != self.sparse_pred_ignore_index
1235
+ last_hidden_state = last_hidden_state[mask_tokens]
1236
+ labels = labels[mask_tokens]
1237
+
1238
+         logits = (
+             self.compiled_head(last_hidden_state)
+             if self.config.reference_compile
+             else self.decoder(self.head(last_hidden_state))
+         )
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)
+
+         if self.config._attn_implementation == "flash_attention_2":
+             with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
+                 logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
+
+         if not return_dict:
+             output = (logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
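+ # Minimal usage sketch for `ModernBergForMaskedLM` above. The checkpoint name is a placeholder,
+ # not a published model; substitute whatever checkpoint this file actually ships with.
+ #
+ #     from transformers import AutoTokenizer
+ #
+ #     tokenizer = AutoTokenizer.from_pretrained("org/modernberg-base")  # hypothetical checkpoint
+ #     model = ModernBergForMaskedLM.from_pretrained("org/modernberg-base")
+ #     inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
+ #     logits = model(**inputs).logits
+ #     mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
+ #     print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))
+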
+ @add_start_docstrings(
+     "The ModernBerg Model with a sequence classification head on top that performs pooling.",
+     MODERNBERG_START_DOCSTRING,
+ )
+ class ModernBergForSequenceClassification(ModernBergPreTrainedModel):
+     def __init__(self, config: ModernBergConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.config = config
+
+         self.model = ModernBergModel(config)
+         self.head = ModernBergPredictionHead(config)
+         self.drop = torch.nn.Dropout(config.classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=SequenceClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         sliding_window_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         indices: Optional[torch.Tensor] = None,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         max_seqlen: Optional[int] = None,
+         batch_size: Optional[int] = None,
+         seq_len: Optional[int] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         self._maybe_set_compile()
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             sliding_window_mask=sliding_window_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             indices=indices,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=max_seqlen,
+             batch_size=batch_size,
+             seq_len=seq_len,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         last_hidden_state = outputs[0]
+
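+         # Pool the sequence into a single vector: either the final hidden state of the first
+         # ([CLS]) token or an attention-mask-weighted mean over all tokens, depending on
+         # `config.classifier_pooling`. Note that the mean branch assumes `attention_mask` is set.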
+         if self.config.classifier_pooling == "cls":
+             last_hidden_state = last_hidden_state[:, 0]
+         elif self.config.classifier_pooling == "mean":
+             last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
+                 dim=1, keepdim=True
+             )
+
+         pooled_output = self.head(last_hidden_state)
+         pooled_output = self.drop(pooled_output)
+         logits = self.classifier(pooled_output)
+
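+         # If `config.problem_type` is unset, infer it from the labels: `num_labels == 1` selects
+         # regression (MSE), integer-typed labels select single-label classification (cross-entropy),
+         # and anything else falls back to multi-label classification (BCE with logits).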
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+
+         if not return_dict:
+             output = (logits,)
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
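+ # Quick fine-tuning-style sketch for `ModernBergForSequenceClassification` above, reusing the
+ # hypothetical checkpoint and tokenizer from the masked-LM example (num_labels=2 is illustrative):
+ #
+ #     model = ModernBergForSequenceClassification.from_pretrained("org/modernberg-base", num_labels=2)
+ #     batch = tokenizer("a genuinely great film", return_tensors="pt")
+ #     outputs = model(**batch, labels=torch.tensor([1]))
+ #     outputs.loss.backward()
+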
+ @add_start_docstrings(
+     "The ModernBerg Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.",
+     MODERNBERG_START_DOCSTRING,
+ )
+ class ModernBergForTokenClassification(ModernBergPreTrainedModel):
+     def __init__(self, config: ModernBergConfig):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         self.model = ModernBergModel(config)
+         self.head = ModernBergPredictionHead(config)
+         self.drop = torch.nn.Dropout(config.classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     @add_start_docstrings_to_model_forward(MODERNBERG_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         checkpoint=_CHECKPOINT_FOR_DOC,
+         output_type=TokenClassifierOutput,
+         config_class=_CONFIG_FOR_DOC,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         sliding_window_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         indices: Optional[torch.Tensor] = None,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         max_seqlen: Optional[int] = None,
+         batch_size: Optional[int] = None,
+         seq_len: Optional[int] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         self._maybe_set_compile()
+
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             sliding_window_mask=sliding_window_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             indices=indices,
+             cu_seqlens=cu_seqlens,
+             max_seqlen=max_seqlen,
+             batch_size=batch_size,
+             seq_len=seq_len,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         last_hidden_state = outputs[0]
+
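+         # No pooling here: the prediction head, dropout, and classifier run on every position,
+         # producing a set of `num_labels` logits for each token in the sequence.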
+         last_hidden_state = self.head(last_hidden_state)
+         last_hidden_state = self.drop(last_hidden_state)
+         logits = self.classifier(last_hidden_state)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ __all__ = [
+     "ModernBergConfig",
+     "ModernBergModel",
+     "ModernBergPreTrainedModel",
+     "ModernBergForMaskedLM",
+     "ModernBergForSequenceClassification",
+     "ModernBergForTokenClassification",
+ ]