xiulinyang committed
Commit 7afc50c · verified · 1 Parent(s): 1d00b43

add remote code + model files

__init__.py ADDED
@@ -0,0 +1 @@
+# for HF remote code
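The empty `__init__.py` marks the directory as a package so the files below can be pulled in as Hugging Face remote code. As a rough usage sketch (not part of this commit; the repo id is a placeholder, and it assumes the checkpoint's `config.json` registers `AlibiConfig` / `AlibiForCausalLM` via `auto_map`):

# Hypothetical loading sketch -- repo id is a placeholder, auto_map assumed.
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "xiulinyang/<model-repo>"  # placeholder, not taken from this commit
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)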
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (513 Bytes)

__pycache__/configuration_alibi.cpython-310.pyc ADDED
Binary file (2.01 kB)

__pycache__/modeling_alibi.cpython-310.pyc ADDED
Binary file (15.1 kB)

configuration_alibi.py ADDED
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+
+from typing import Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class AlibiConfig(PretrainedConfig):
+
+    model_type = 'transformer-project_fox'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        vocab_size: int = 32000,
+        hidden_size: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        num_hidden_layers: int = 24,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        hidden_act: str = "swish",
+        window_size: Optional[int] = None,
+        max_position_embeddings: int = 2048,
+        initializer_range: float = 0.02,
+        elementwise_affine: Optional[bool] = True,
+        norm_eps: float = 1e-6,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        attention_bias: bool = False,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        rope_base: float = 500000.0,
+        use_rope: bool = False,
+        use_alibi: bool = True,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.window_size = window_size
+        self.max_position_embeddings = max_position_embeddings
+
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.elementwise_affine = elementwise_affine
+        self.norm_eps = norm_eps
+        self.use_cache = use_cache
+        self.attention_bias = attention_bias
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.fuse_norm = fuse_norm
+        self.rope_base = rope_base
+        self.use_rope = use_rope
+        self.use_alibi = use_alibi
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
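For reference, a minimal sketch of instantiating this config directly (not part of this commit). The values shown are just the defaults defined above, and the import path assumes the file is importable locally:

# Sketch: construct the config with its default ALiBi setup.
from configuration_alibi import AlibiConfig  # assumes the file is on sys.path

config = AlibiConfig(
    vocab_size=32000,
    hidden_size=2048,
    num_hidden_layers=24,
    num_heads=32,
    use_alibi=True,   # ALiBi positional bias enabled
    use_rope=False,   # RoPE disabled by default
)
print(config.model_type)    # 'transformer-project_fox'
print(config.num_kv_heads)  # None -> Attention falls back to num_heads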
modeling_alibi.py ADDED
@@ -0,0 +1,567 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding
+from torch.nn import functional as F
+from fla.modules.activations import swiglu_linear
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+                                           CausalLMOutputWithPast)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from einops import rearrange
+
+from forgetting_transformer.model.alibi.configuration_alibi import AlibiConfig
+
+from functools import partial
+
+logger = logging.get_logger(__name__)
+
+class Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int = 2048,
+        num_heads: int = 32,
+        num_kv_heads: Optional[int] = None,
+        window_size: Optional[int] = None,
+        max_position_embeddings: Optional[int] = None,
+        rope_base: float = 500000.0,
+        use_rope: bool = False,
+        use_alibi: bool = True,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.num_heads = num_heads
+        if num_kv_heads is None:
+            self.num_kv_heads = self.num_heads
+        else:
+            self.num_kv_heads = num_kv_heads
+        self.num_kv_groups = num_heads // self.num_kv_heads
+        self.hidden_size = hidden_size
+        self.head_dim = self.hidden_size // self.num_heads
+        self.kv_dim = self.num_kv_heads * self.head_dim
+        self.window_size = window_size
+        self.max_position_embeddings = max_position_embeddings
+        self.layer_idx = layer_idx
+
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+        if use_rope:
+            self.rotary = RotaryEmbedding(self.head_dim, base=rope_base)
+        else:
+            self.rotary = None
+
+        if use_alibi:
+            slopes = torch.tensor(self._get_slopes(self.num_heads), dtype=torch.float32)
+            self.register_buffer("alibi_slopes", slopes.view(1, -1, 1, 1), persistent=False)
+
+        self.apply(self._initialize_weights)
+
+    def _initialize_weights(self, module: nn.Module):
+        pass
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        B, T, _ = hidden_states.size()
+        q = rearrange(self.q_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
+        k = rearrange(self.k_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_kv_heads)
+        v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_kv_heads)
+
+        seqlen_offset = 0
+        max_seqlen = q.shape[1]
+        if past_key_values is not None:
+            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
+            max_seqlen = q.shape[1] + seqlen_offset
+        if self.max_position_embeddings is not None:
+            max_seqlen = max(max_seqlen, self.max_position_embeddings)
+
+        if self.rotary is not None:
+            q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
+
+        q = rearrange(q, 'b t h d -> b h t d')
+        k = rearrange(k, 'b t h d -> b h t d')
+        v = rearrange(v, 'b t h d -> b h t d')
+
+        if past_key_values is not None:
+            k, v = past_key_values.update(k, v, self.layer_idx)
+
+        if self.num_kv_groups > 1:
+            k = k.repeat_interleave(self.num_kv_groups, dim=1)  # [B, H, Tk, D]
+            v = v.repeat_interleave(self.num_kv_groups, dim=1)  # [B, H, Tk, D]
+
+        B, H, Tq, Dh = q.shape
+        Tk = k.size(2)
+
+        scale = 1.0 / math.sqrt(Dh)
+        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+        pos_q = seqlen_offset + torch.arange(Tq, device=scores.device)
+        pos_k = torch.arange(Tk, device=scores.device)
+        causal_mask = (pos_k.unsqueeze(0) > pos_q.unsqueeze(1))  # [Tq, Tk]
+        scores = scores.masked_fill(causal_mask.view(1, 1, Tq, Tk), float('-inf'))
+
+        if hasattr(self, "alibi_slopes"):
+            rel = (pos_q.unsqueeze(1) - pos_k.unsqueeze(0)).to(torch.float32)  # [Tq, Tk]
+            alibi_bias = -self.alibi_slopes.to(scores.device) * rel.view(1, 1, Tq, Tk)  # [1, H, Tq, Tk]
+            scores = scores + alibi_bias.to(scores.dtype)
+
+        if attention_mask is not None and attention_mask.shape[-1] == Tk:
+            pad_mask = (attention_mask == 0).view(B, 1, 1, Tk)
+            scores = scores.masked_fill(pad_mask, float('-inf'))
+
+        if self.window_size is not None:
+            past_too_far = (pos_k.view(1, Tk) < (pos_q.view(Tq, 1) - (self.window_size - 1)))
+            scores = scores.masked_fill(past_too_far.view(1, 1, Tq, Tk), float('-inf'))
+
+        attn = torch.softmax(scores, dim=-1)      # [B, H, Tq, Tk]
+        o = torch.matmul(attn, v)                 # [B, H, Tq, Dh]
+        o = rearrange(o, 'b h t d -> b t (h d)')  # [B, Tq, H*Dh] = [B, Tq, hidden_size]
+        o = self.o_proj(o)
+
+        attentions = attn if output_attentions else None
+        return o, attentions, past_key_values
+
+    def _get_slopes(self, n):
+        """
+        Get slopes for the ALiBi positional bias.
+        n : int = number of heads.
+        For best performance, restrict n to a power of 2.
+        """
+
+        def get_slopes_power_of_2(n):
+            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio**i for i in range(n)]
+
+        if math.log2(n).is_integer():
+            return get_slopes_power_of_2(n)
+        else:
+            closest_power_of_2 = 2 ** math.floor(math.log2(n))
+            return (
+                get_slopes_power_of_2(closest_power_of_2)
+                + self._get_slopes(2 * closest_power_of_2)[0::2][
+                    : n - closest_power_of_2
+                ]
+            )
+
+
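As an illustration of the ALiBi bias that `Attention.forward` adds to the attention scores, here is a small standalone sketch (not part of this commit) that mirrors `_get_slopes` and the `rel` / `alibi_bias` construction for a toy case with a power-of-2 head count:

# Standalone sketch of the ALiBi bias, for 4 heads and sequence length 5.
import math
import torch

def get_slopes_power_of_2(n):
    # same formula as _get_slopes above (valid when n is a power of 2)
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * start**i for i in range(n)]

num_heads, T = 4, 5
slopes = torch.tensor(get_slopes_power_of_2(num_heads)).view(1, num_heads, 1, 1)
pos = torch.arange(T)
rel = (pos.unsqueeze(1) - pos.unsqueeze(0)).float()  # [T, T], distance query - key
bias = -slopes * rel.view(1, 1, T, T)                # [1, H, T, T]
# For each head the bias is 0 on the diagonal and grows more negative with
# distance into the past; future positions are removed by the causal mask.
print(bias[0, 0])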
+class TransformerMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'swish'
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        # the final number of params is `hidden_ratio * hidden_size^2`
+        # `intermediate_size` is chosen to be a multiple of 256 closest to `2/3 * hidden_size * hidden_ratio`
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
+            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, x):
+        y = self.gate_proj(x)
+        gate, y = y.chunk(2, -1)
+        # TODO: maybe wrap swiglu_linear in custom_fwd/custom_bwd
+        return swiglu_linear(
+            gate, y,
+            self.down_proj.weight.to(y.dtype),
+            self.down_proj.bias.to(y.dtype) if self.down_proj.bias is not None else self.down_proj.bias
+        )
+
+
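A quick arithmetic sketch (not part of this commit) of the `intermediate_size` rule above for the default config (`hidden_size=2048`, `hidden_ratio=4`): take `2/3 * hidden_size * hidden_ratio`, then round up to a multiple of 256:

# Reproduces the default intermediate_size computation in TransformerMLP.
hidden_size, hidden_ratio = 2048, 4
intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)        # 5461
intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)   # 5632
print(intermediate_size)  # gate_proj maps 2048 -> 2 * 5632 (gate and value halves)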
+class TransformerBlock(nn.Module):
+    def __init__(self, config: AlibiConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
+        self.attn = Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_heads,
+            num_kv_heads=config.num_kv_heads,
+            window_size=config.window_size,
+            use_alibi=config.use_alibi,
+            max_position_embeddings=config.max_position_embeddings,
+            rope_base=config.rope_base,
+            use_rope=config.use_rope,
+            layer_idx=layer_idx
+        )
+        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
+        self.mlp = TransformerMLP(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act
+        )
+
+    def forward_attn(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ):
+        # residual is handled outside
+        # residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states, attentions, past_key_values = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+        return hidden_states, attentions, past_key_values
+
+    def forward_mlp(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ):
+        hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        gradient_checkpointing: bool = False
+        # **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+
+        residual = hidden_states
+
+        if gradient_checkpointing:
+            forward_attn = partial(torch.utils.checkpoint.checkpoint, self.forward_attn, use_reentrant=False)
+            forward_mlp = partial(torch.utils.checkpoint.checkpoint, self.forward_mlp, use_reentrant=False)
+        else:
+            forward_attn = self.forward_attn
+            forward_mlp = self.forward_mlp
+
+        hidden_states, attentions, past_key_values = forward_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions
+        )
+
+        hidden_states = forward_mlp(
+            hidden_states,
+            residual,
+        )
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (attentions,)
+
+        if use_cache:
+            outputs += (past_key_values,)
+
+        return outputs
+
+
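The `gradient_checkpointing` branch above wraps the attention and MLP sub-forwards with `torch.utils.checkpoint` through `functools.partial`, so the same call site works with or without activation checkpointing. A minimal standalone sketch of that pattern (not part of this commit):

# Toy version of the checkpointing switch used in TransformerBlock.forward.
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

use_checkpointing = True
if use_checkpointing:
    forward = partial(torch.utils.checkpoint.checkpoint, layer, use_reentrant=False)
else:
    forward = layer

y = forward(x)        # activations are recomputed during backward when checkpointing
y.sum().backward()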
+class TransformerPreTrainedModel(PreTrainedModel):
+
+    config_class = AlibiConfig
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['TransformerBlock']
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class AlibiModel(TransformerPreTrainedModel):
+
+    def __init__(self, config: AlibiConfig):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([TransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if output_attentions:
+            warnings.warn(
+                "`AlibiModel` does not support outputting attention weights yet, so `output_attentions` is set to `False`."
+            )
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                gradient_checkpointing=self.gradient_checkpointing and self.training
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_attns
+        )
+
+
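`AlibiModel.forward` converts legacy tuples into a `DynamicCache`, and `Attention.forward` calls `get_seq_length` / `update` on it during decoding. A small sketch of that cache API in isolation (not part of this commit; exact signatures can vary across `transformers` versions):

# Toy use of the transformers DynamicCache that the model relies on.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# prefill: 8 tokens, 4 KV heads, head_dim 16 (shapes follow [B, H, T, D])
k = torch.randn(1, 4, 8, 16)
v = torch.randn(1, 4, 8, 16)
k_all, v_all = cache.update(k, v, layer_idx=0)
print(cache.get_seq_length(0))   # 8 -> used as seqlen_offset for the next step

# decode one more token: update() returns the concatenated keys/values
k1 = torch.randn(1, 4, 1, 16)
v1 = torch.randn(1, 4, 1, 16)
k_all, v_all = cache.update(k1, v1, layer_idx=0)
print(k_all.shape)               # torch.Size([1, 4, 9, 16])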
+class AlibiForCausalLM(TransformerPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = AlibiModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        past_key_values: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs
+    ):
+        # only keep the last token of `input_ids` if `past_key_values` is passed along.
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': kwargs.get('use_cache'),
+            'attention_mask': attention_mask,
+        })
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        hidden_states = outputs[0]
+
+        loss = None
+        if labels is not None:
+            if self.config.fuse_cross_entropy:
+                loss_fct = FusedCrossEntropyLoss(inplace_backward=True, reduction='none')
+            else:
+                loss_fct = nn.CrossEntropyLoss(reduction='none')
+            logits = self.lm_head(hidden_states)
+            # Enable model parallelism
+            labels = labels.to(logits.device)
+            # labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
+            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+            loss = loss.view(*labels.size())
+            del logits
+            logits = None
+        else:
+            logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
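Finally, a hedged smoke-test sketch for the classes above (not part of this commit). It assumes the `fla` package and the `forgetting_transformer` import used in `modeling_alibi.py` are both importable, and that a CUDA device is available for the fused `fla` kernels; adjust the imports to your local layout:

# Tiny forward-pass smoke test for AlibiForCausalLM under the assumptions above.
import torch
from configuration_alibi import AlibiConfig
from modeling_alibi import AlibiForCausalLM

config = AlibiConfig(
    vocab_size=1000,
    hidden_size=128,
    num_hidden_layers=2,
    num_heads=4,
    fuse_cross_entropy=False,  # use plain nn.CrossEntropyLoss for this sketch
)
model = AlibiForCausalLM(config).cuda()

input_ids = torch.randint(0, config.vocab_size, (1, 16), device="cuda")
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss.shape)      # per-token loss, [1, 16], since reduction='none'
out = model(input_ids=input_ids)
print(out.logits.shape)    # [1, 16, 1000]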