SmerkyG committed
Commit 3f1e910 · verified · 1 Parent(s): 47905f4

Add files using upload-large-folder tool

config.json ADDED
@@ -0,0 +1,41 @@
+ {
+   "architectures": [
+     "RWKV6Qwen2ForCausalLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_rwkv6qwen2.RWKV6Qwen2Config",
+     "AutoModelForCausalLM": "modeling_rwkv6qwen2.RWKV6Qwen2ForCausalLM"
+   },
+   "attention_bias": true,
+   "attention_dropout": 0.0,
+   "attention_output_bias": false,
+   "balance_state": true,
+   "bos_token_id": 151643,
+   "eos_token_id": 151643,
+   "gate_rank_type": 1,
+   "groupnorm_att": false,
+   "hidden_act": "silu",
+   "hidden_size": 3584,
+   "initializer_range": 0.02,
+   "intermediate_size": 18944,
+   "lora_rank_decay": 96,
+   "lora_rank_tokenshift": 96,
+   "lora_rank_gate": 0,
+   "max_position_embeddings": 131072,
+   "max_window_layers": 28,
+   "model_type": "rwkv6qwen2",
+   "num_attention_heads": 28,
+   "num_hidden_layers": 28,
+   "num_key_value_heads": 4,
+   "rms_norm_eps": 1e-06,
+   "rope_theta": 1000000.0,
+   "sliding_window": 131072,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.1",
+   "use_cache": true,
+   "use_rope": false,
+   "use_tokenshift": true,
+   "use_sliding_window": false,
+   "vocab_size": 152064
+ }
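
Because `auto_map` routes `AutoConfig` and `AutoModelForCausalLM` to the `configuration_rwkv6qwen2.py` and `modeling_rwkv6qwen2.py` modules shipped with the weights, the checkpoint loads through the standard `transformers` auto classes with `trust_remote_code=True`, as `generate.py` later in this commit does. A minimal sketch, assuming the repository has been downloaded locally (the path below is a placeholder):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo = "./rwkv6qwen2-7b"  # placeholder: local clone or hub id of this repository

# auto_map dispatches these calls to RWKV6Qwen2Config / RWKV6Qwen2ForCausalLM,
# so trust_remote_code=True is required.
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True
)
print(config.model_type, config.num_hidden_layers, config.hidden_size)  # rwkv6qwen2 28 3584
```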
configuration_rwkv6qwen2.py ADDED
@@ -0,0 +1,206 @@
+ # coding=utf-8
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """RWKV6Qwen2 model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RWKV6Qwen2Config(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`RWKV6Qwen2Model`]. It is used to instantiate a
+     RWKV6Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+     with the defaults will yield a similar configuration to that of
+     Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 151936):
+             Vocabulary size of the RWKV6Qwen2 model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`RWKV6Qwen2Model`]
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 22016):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 32):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 32):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details checkout [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
+         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+             The non-linear activation function (function or string) in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 32768):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+             accordingly.
+             Expected contents:
+                 `rope_type` (`str`):
+                     The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                     'llama3'], with 'default' being the original RoPE implementation.
+                 `factor` (`float`, *optional*):
+                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                     original maximum pre-trained length.
+                 `original_max_position_embeddings` (`int`, *optional*):
+                     Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                     pretraining.
+                 `attention_factor` (`float`, *optional*):
+                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                     computation. If unspecified, it defaults to value recommended by the implementation, using the
+                     `factor` field to infer the suggested value.
+                 `beta_fast` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 32.
+                 `beta_slow` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 1.
+                 `short_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `long_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `low_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                 `high_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+         use_sliding_window (`bool`, *optional*, defaults to `False`):
+             Whether to use sliding window attention.
+         sliding_window (`int`, *optional*, defaults to 4096):
+             Sliding window attention (SWA) window size. If not specified, will default to `4096`.
+         max_window_layers (`int`, *optional*, defaults to 28):
+             The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+
+     ```python
+     >>> from transformers import RWKV6Qwen2Model, RWKV6Qwen2Config
+
+     >>> # Initializing a RWKV6Qwen2 style configuration
+     >>> configuration = RWKV6Qwen2Config()
+
+     >>> # Initializing a model from the RWKV6Qwen2-7B style configuration
+     >>> model = RWKV6Qwen2Model(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "rwkv6qwen2"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=151936,
+         hidden_size=4096,
+         intermediate_size=22016,
+         num_hidden_layers=32,
+         num_attention_heads=32,
+         num_key_value_heads=32,
+         lora_rank_tokenshift=None,
+         lora_rank_decay=None,
+         hidden_act="silu",
+         max_position_embeddings=32768,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         tie_word_embeddings=False,
+         use_rope=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         use_sliding_window=False,
+         sliding_window=4096,
+         max_window_layers=28,
+         attention_dropout=0.0,
+         attention_bias=True,
+         attention_output_bias=False,
+         gate_rank_type=1,
+         lora_rank_gate=None,
+         balance_state=True,
+         groupnorm_att=False,
+         use_tokenshift=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.use_sliding_window = use_sliding_window
+         self.sliding_window = sliding_window if use_sliding_window else None
+         self.max_window_layers = max_window_layers
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.lora_rank_tokenshift = lora_rank_tokenshift
+         self.lora_rank_decay = lora_rank_decay
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.use_rope = use_rope
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.attention_dropout = attention_dropout
+         # Validate the correctness of rotary position embeddings parameters
+         # BC: if there is a 'type' field, move it to 'rope_type'.
+         if self.rope_scaling is not None and "type" in self.rope_scaling:
+             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
+         rope_config_validation(self)
+
+         self.attention_bias = attention_bias
+         self.attention_output_bias = attention_output_bias
+         self.gate_rank_type = gate_rank_type
+         self.lora_rank_gate = lora_rank_gate
+         self.balance_state = balance_state
+         self.groupnorm_att = groupnorm_att
+         self.use_tokenshift = use_tokenshift
+
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
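
For reference, config.json above overrides a number of these defaults. A minimal sketch of constructing the same configuration programmatically, with values copied from config.json in this commit (anything omitted keeps the defaults defined in `__init__`):

```python
from configuration_rwkv6qwen2 import RWKV6Qwen2Config

config = RWKV6Qwen2Config(
    vocab_size=152064,
    hidden_size=3584,
    intermediate_size=18944,
    num_hidden_layers=28,
    num_attention_heads=28,
    num_key_value_heads=4,
    lora_rank_tokenshift=96,
    lora_rank_decay=96,
    lora_rank_gate=0,
    gate_rank_type=1,
    max_position_embeddings=131072,
    rope_theta=1000000.0,
    use_rope=False,
    tie_word_embeddings=False,
)

# use_sliding_window defaults to False, so __init__ normalizes sliding_window to None.
assert config.sliding_window is None
print(config.model_type)  # "rwkv6qwen2"
```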
examine_ckpt.py ADDED
@@ -0,0 +1,25 @@
+ import sys
+ import math
+ import torch
+ from collections import OrderedDict
+ import re
+ from safetensors.torch import load_file
+
+ if len(sys.argv) != 2:
+     print(f"Examines checkpoint keys")
+     print("Usage: python examine_ckpt.py in_file")
+     exit()
+
+ model_path = sys.argv[1]
+
+ print("Loading file...")
+ if model_path.lower().endswith('.safetensors'):
+     state_dict = load_file(model_path)
+ else:
+     state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
+
+ for name, p in state_dict.items():
+     if p.numel() == 0:
+         print(name, p.dtype, p.shape)
+     else:
+         print(name, p.dtype, p.shape, float(p.min()), float(p.max()))
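
As a hedged variation on the script above, the same inspection can be run inline against one of the shards uploaded in this commit, filtering to the RWKV-6-specific tensors (the `time_maa_*` / `time_decay*` keys listed in model.safetensors.index.json below); the filter is purely illustrative:

```python
from safetensors.torch import load_file

# One of the four shards added in this commit.
state_dict = load_file("model-00001-of-00004.safetensors")

# Show only the RWKV-6 parameters grafted onto the Qwen2 attention layout.
for name, p in state_dict.items():
    if "time_maa" in name or "time_decay" in name:
        print(name, tuple(p.shape), p.dtype)
```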
generate.py ADDED
@@ -0,0 +1,119 @@
+ import sys, os
+
+ import torch
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+ from configs import parse_cmdline_configs
+ from pydoc import locate
+
+ from dataclasses import dataclass
+ from typing import Any, Callable
+
+ moby_dick = """Call me Ishmael. Some years ago—never mind how long precisely—having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating the circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle to prevent me from deliberately stepping into the street, and methodically knocking people’s hats off—then, I account it high time to get to sea as soon as I can. This is my substitute for pistol and ball. With a philosophical flourish Cato throws himself upon his sword; I quietly take to the ship. There is nothing surprising in this. If they but knew it, almost all men in their degree, some time or other, cherish very nearly the same feelings towards the ocean with me.
+
+ There now is your insular city of the Manhattoes, belted round by wharves as Indian isles by coral reefs—commerce surrounds it with her surf. Right and left, the streets take you waterward. Its extreme downtown is the battery, where that noble mole is washed by waves, and cooled by breezes, which a few hours previous were out of sight of land. Look at the crowds of water-gazers there.
+
+ Circumambulate the city of a dreamy Sabbath afternoon. Go from Corlears Hook to Coenties Slip, and from thence, by Whitehall, northward. What do you see?—Posted like silent sentinels all around the town, stand thousands upon thousands of mortal men fixed in ocean reveries. Some leaning against the spiles; some seated upon the pier-heads; some looking over the bulwarks of ships from China; some high aloft in the rigging, as if striving to get a still better seaward peep. But these are all landsmen; of week days pent up in lath and plaster—tied to counters, nailed to benches, clinched to desks. How then is this? Are the green fields gone? What do they here?
+
+ But look! here come more crowds, pacing straight for the water, and seemingly bound for a dive. Strange! Nothing will content them but the extremest limit of the land; loitering under the shady lee of yonder warehouses will not suffice. No. They must get just as nigh the water as they possibly can without falling in. And there they stand—miles of them—leagues. Inlanders all, they come from lanes and alleys, streets and avenues—north, east, south, and west. Yet here they all unite. Tell me, does the magnetic virtue of the needles of the compasses of all those ships attract them thither?
+
+ Once more. Say you are in the country; in some high land of lakes. Take almost any path you please, and ten to one it carries you down in a dale, and leaves you there by a pool in the stream. There is magic in it. Let the most absent-minded of men be plunged in his deepest reveries—stand that man on his legs, set his feet a-going, and he will infallibly lead you to water, if water there be in all that region. Should you ever be athirst in the great American desert, try this experiment, if your caravan happen to be supplied with a metaphysical professor. Yes, as every one knows, meditation and water are wedded for ever.
+
+ But here is an artist. He desires to paint you the dreamiest, shadiest, quietest, most enchanting bit of romantic landscape in all the valley of the Saco. What is the chief element he employs? There stand his trees, each with a hollow trunk, as if a hermit and a crucifix were within; and here sleeps his meadow, and there sleep his cattle; and up from yonder cottage goes a sleepy smoke. Deep into distant woodlands winds a mazy way, reaching to overlapping spurs of mountains bathed in their hill-side blue. But though the picture lies thus tranced, and though this pine-tree shakes down its sighs like leaves upon this shepherd’s head, yet all were vain, unless the shepherd’s eye were fixed upon the magic stream before him. Go visit the Prairies in June, when for scores on scores of miles you wade knee-deep among Tiger-lilies—what is the one charm wanting?—Water—there is not a drop of water there! Were Niagara but a cataract of sand, would you travel your thousand miles to see it? Why did the poor poet of Tennessee, upon suddenly receiving two handfuls of silver, deliberate whether to buy him a coat, which he sadly needed, or invest his money in a pedestrian trip to Rockaway Beach? Why is almost every robust healthy boy with a robust healthy soul in him, at some time or other crazy to go to sea? Why upon your first voyage as a passenger, did you yourself feel such a mystical vibration, when first told that you and your ship were now out of sight of land? Why did the old Persians hold the sea holy? Why did the Greeks give it a separate deity, and own brother of Jove? Surely all this is not without meaning. And still deeper the meaning of that story of Narcissus, who because he could not grasp the tormenting, mild image he saw in the fountain, plunged into it and was drowned. But that same image, we ourselves see in all rivers and oceans. It is the image of the ungraspable phantom of life; and this is the key to it all.
+
+ Now, when I say that I am in the habit of going to sea whenever I begin to grow hazy about the eyes, and begin to be over conscious of my lungs, I do not mean to have it inferred that I ever go to sea as a passenger. For to go as a passenger you must needs have a purse, and a purse is but a rag unless you have something in it. Besides, passengers get sea-sick—grow quarrelsome—don’t sleep of nights—do not enjoy themselves much, as a general thing;—no, I never go as a passenger; nor, though I am something of a salt, do I ever go to sea as a Commodore, or a Captain, or a Cook. I abandon the glory and distinction of such offices to those who like them. For my part, I abominate all honorable respectable toils, trials, and tribulations of every kind whatsoever. It is quite as much as I can do to take care of myself, without taking care of ships, barques, brigs, schooners, and what not. And as for going as cook,—though I confess there is considerable glory in that, a cook being a sort of officer on ship-board—yet, somehow, I never fancied broiling fowls;—though once broiled, judiciously buttered, and judgmatically salted and peppered, there is no one who will speak more respectfully, not to say reverentially, of a broiled fowl than I will. It is out of the idolatrous dotings of the old Egyptians upon broiled ibis and roasted river horse, that you see the mummies of those creatures in their huge bake-houses the pyramids.
+
+ No, when I go to sea, I go as a simple sailor, right before the mast, plumb down into the forecastle, aloft there to the royal mast-head. True, they rather order me about some, and make me jump from spar to spar, like a grasshopper in a May meadow. And at first, this sort of thing is unpleasant enough. It touches one’s sense of honor, particularly if you come of an old established family in the land, the Van Rensselaers, or Randolphs, or Hardicanutes. And more than all, if just previous to putting your hand into the tar-pot, you have been lording it as a country schoolmaster, making the tallest boys stand in awe of you. The transition is a keen one, I assure you, from a schoolmaster to a sailor, and requires a strong decoction of Seneca and the Stoics to enable you to grin and bear it. But even this wears off in time.
+
+ What of it, if some old hunks of a sea-captain orders me to get a broom and sweep down the decks? What does that indignity amount to, weighed, I mean, in the scales of the New Testament? Do you think the archangel Gabriel thinks anything the less of me, because I promptly and respectfully obey that old hunks in that particular instance? Who ain’t a slave? Tell me that. Well, then, however the old sea-captains may order me about—however they may thump and punch me about, I have the satisfaction of knowing that it is all right; that everybody else is one way or other served in much the same way—either in a physical or metaphysical point of view, that is; and so the universal thump is passed round, and all hands should rub each other’s shoulder-blades, and be content.
+
+ Again, I always go to sea as a sailor, because they make a point of paying me for my trouble, whereas they never pay passengers a single penny that I ever heard of. On the contrary, passengers themselves must pay. And there is all the difference in the world between paying and being paid. The act of paying is perhaps the most uncomfortable infliction that the two orchard thieves entailed upon us. But being paid,—what will compare with it? The urbane activity with which a man receives money is really marvellous, considering that we so earnestly believe money to be the root of all earthly ills, and that on no account can a monied man enter heaven. Ah! how cheerfully we consign ourselves to perdition!
+
+ Finally, I always go to sea as a sailor, because """
+
+
+ @dataclass
+ class CLI_Config:
+     tokenizer_path: str
+     model_path: str
+     attn_path: str = 'rwkv6attn.RWKV6Attention'
+     prompt:str = 'How many quarts are in a gallon?'
+     max_len:int = 30
+     attempts:int = 1
+     precision: int | str = 'bf16'
+     attn_classes_path: str = 'transformers.models.qwen2.modeling_qwen2.QWEN2_ATTENTION_CLASSES' # 'transformers.models.llama.modeling_llama.LLAMA_ATTENTION_CLASSES'
+     seed: int | None = None
+     train:Any = None
+
+ config, errors = parse_cmdline_configs(sys.argv[1:], CLI_Config)
+ if errors != '':
+     print(errors)
+     exit()
+
+ match config.precision:
+     case 32:
+         dtype = torch.float32
+     case '32':
+         dtype = torch.float32
+     case 16:
+         dtype = torch.float16
+     case '16':
+         dtype = torch.float16
+     case 'bf16':
+         dtype = torch.bfloat16
+     case _:
+         print("Bad precision type specified")
+         exit()
+
+ # avoid 1000 huggingface warnings "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...""
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+ print(f'Loading model - {config.model_path}')
+
+ model_config = AutoConfig.from_pretrained(config.model_path, trust_remote_code=True)
+
+ # if config.model_path.startswith('.'):
+ #     # replace attention classes
+ #     ReplacementSelfAttentionType = locate(config.attn_path)
+ #     assert isinstance(ReplacementSelfAttentionType, Callable)
+ #     attn_classes_dict = locate(config.attn_classes_path)
+ #     assert isinstance(attn_classes_dict, dict), 'could not find attention classes dict at path provided'
+ #     for key in list(attn_classes_dict.keys()):
+ #         attn_classes_dict[key] = ReplacementSelfAttentionType
+
+ model = AutoModelForCausalLM.from_pretrained(config.model_path, config=model_config, torch_dtype=dtype, device_map='cuda', trust_remote_code=True)
+
+ tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True)
+
+ #device = 'cuda'
+ #model = model.to(device=device, dtype=dtype)
+ model.eval()
+
+ if config.seed is None:
+     config.seed = 1234
+
+ from transformers import AutoTokenizer, Qwen2ForCausalLM, set_seed
+
+ set_seed(config.seed)
+
+ text = config.prompt
+
+ messages = [
+     {"role": "system", "content": "You are a helpful assistant."},
+     {"role": "user", "content": config.prompt}
+ ]
+ text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True
+ )
+ inputs = tokenizer(text, return_tensors="pt").to('cuda')
+
+ # Generate
+ for i in range(config.attempts):
+     print(f"Attempt {i+1}:")
+     generate_ids = model.generate(inputs.input_ids, max_new_tokens=config.max_len, use_cache=True, do_sample=True, temperature=1.0, top_p=1.0)#, typical_p=0.95)#top_p=0.7, repetition_penalty=0.25)
+     print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, use_cache=False)[0])
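
Note that generate.py relies on a local `configs.parse_cmdline_configs` helper that is not part of this upload, so its exact command-line syntax is not shown here. Below is a hedged, dependency-free sketch of the same flow (chat template, then sampled generation); the repository path is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

repo = "./rwkv6qwen2-7b"  # placeholder: local clone or hub id of this repository

tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True
).eval()

set_seed(1234)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "How many quarts are in a gallon?"},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to("cuda")

out = model.generate(inputs.input_ids, max_new_tokens=30, do_sample=True, temperature=1.0, top_p=1.0)
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```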
generation_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "bos_token_id": 151643,
+   "pad_token_id": 151643,
+   "do_sample": true,
+   "eos_token_id": [
+     151645,
+     151643
+   ],
+   "repetition_penalty": 1.05,
+   "temperature": 0.7,
+   "top_p": 0.8,
+   "top_k": 20,
+   "transformers_version": "4.37.0"
+ }
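
These are the sampling defaults that `model.generate()` falls back to when no explicit arguments are passed; generate.py above overrides temperature and top_p per call. A minimal sketch of reading them, assuming the file sits alongside the weights (placeholder path):

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig.from_pretrained("./rwkv6qwen2-7b")  # placeholder path
print(gen_cfg.temperature, gen_cfg.top_p, gen_cfg.top_k, gen_cfg.repetition_penalty)
# Explicit kwargs such as generate(..., temperature=1.0, top_p=1.0) override these defaults.
```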
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:662ccdc789fe510fdf88cda5727f77061d133d9d31f0591d7a008b8e8add6b04
+ size 4955133312
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b9cdf366e3cca4c985365bb2571668ce7d2d3649071bf270e1166a18208d55c4
+ size 4865370824
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:367ae7e5f35819bf5d72cfb4159fd7d560a7cf943380d53a0e64d8e50b38d61c
+ size 4865370920
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cfe5f39f1ff90b14ab41d7b12f852b3e506c40b2230a40735bf181360de839bc
+ size 1497374264
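
The four .safetensors entries above are Git LFS pointer files; the actual shards are tied together by name in model.safetensors.index.json below. A hedged sketch of rebuilding the full state dict from that index, assuming the real shard files have been pulled into the working directory:

```python
import json
from safetensors.torch import load_file

with open("model.safetensors.index.json") as f:
    index = json.load(f)

# weight_map maps each parameter name to the shard file that stores it.
state_dict = {}
for shard in sorted(set(index["weight_map"].values())):
    state_dict.update(load_file(shard))

print(len(state_dict), "tensors /", index["metadata"]["total_size"], "bytes of parameters")
```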
model.safetensors.index.json ADDED
@@ -0,0 +1,682 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 16183172096
4
+ },
5
+ "weight_map": {
6
+ "model.embed_tokens.weight": "model-00001-of-00004.safetensors",
7
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
8
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
9
+ "model.layers.0.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
10
+ "model.layers.0.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
11
+ "model.layers.0.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
12
+ "model.layers.0.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
13
+ "model.layers.0.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
14
+ "model.layers.0.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
15
+ "model.layers.0.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
16
+ "model.layers.0.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
17
+ "model.layers.0.self_attn.time_decay": "model-00001-of-00004.safetensors",
18
+ "model.layers.0.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
19
+ "model.layers.0.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
20
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
21
+ "model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
22
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
23
+ "model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
24
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
25
+ "model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
26
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
27
+ "model.layers.0.self_attn.gate.weight": "model-00001-of-00004.safetensors",
28
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
29
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
30
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
31
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
32
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
33
+ "model.layers.1.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
34
+ "model.layers.1.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
35
+ "model.layers.1.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
36
+ "model.layers.1.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
37
+ "model.layers.1.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
38
+ "model.layers.1.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
39
+ "model.layers.1.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
40
+ "model.layers.1.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
41
+ "model.layers.1.self_attn.time_decay": "model-00001-of-00004.safetensors",
42
+ "model.layers.1.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
43
+ "model.layers.1.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
44
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
45
+ "model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
46
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
47
+ "model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
48
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
49
+ "model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
50
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
51
+ "model.layers.1.self_attn.gate.weight": "model-00001-of-00004.safetensors",
52
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
53
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
54
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
55
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
56
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
57
+ "model.layers.2.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
58
+ "model.layers.2.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
59
+ "model.layers.2.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
60
+ "model.layers.2.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
61
+ "model.layers.2.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
62
+ "model.layers.2.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
63
+ "model.layers.2.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
64
+ "model.layers.2.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
65
+ "model.layers.2.self_attn.time_decay": "model-00001-of-00004.safetensors",
66
+ "model.layers.2.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
67
+ "model.layers.2.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
68
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
69
+ "model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
70
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
71
+ "model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
72
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
73
+ "model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
74
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
75
+ "model.layers.2.self_attn.gate.weight": "model-00001-of-00004.safetensors",
76
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
77
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
78
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
79
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
80
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
81
+ "model.layers.3.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
82
+ "model.layers.3.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
83
+ "model.layers.3.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
84
+ "model.layers.3.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
85
+ "model.layers.3.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
86
+ "model.layers.3.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
87
+ "model.layers.3.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
88
+ "model.layers.3.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
89
+ "model.layers.3.self_attn.time_decay": "model-00001-of-00004.safetensors",
90
+ "model.layers.3.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
91
+ "model.layers.3.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
92
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
93
+ "model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
94
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
95
+ "model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
96
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
97
+ "model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
98
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
99
+ "model.layers.3.self_attn.gate.weight": "model-00001-of-00004.safetensors",
100
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
101
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
102
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
103
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
104
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
105
+ "model.layers.4.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
106
+ "model.layers.4.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
107
+ "model.layers.4.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
108
+ "model.layers.4.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
109
+ "model.layers.4.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
110
+ "model.layers.4.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
111
+ "model.layers.4.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
112
+ "model.layers.4.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
113
+ "model.layers.4.self_attn.time_decay": "model-00001-of-00004.safetensors",
114
+ "model.layers.4.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
115
+ "model.layers.4.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
116
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
117
+ "model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
118
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
119
+ "model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
120
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
121
+ "model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
122
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
123
+ "model.layers.4.self_attn.gate.weight": "model-00001-of-00004.safetensors",
124
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
125
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
126
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
127
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
128
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
129
+ "model.layers.5.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
130
+ "model.layers.5.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
131
+ "model.layers.5.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
132
+ "model.layers.5.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
133
+ "model.layers.5.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
134
+ "model.layers.5.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
135
+ "model.layers.5.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
136
+ "model.layers.5.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
137
+ "model.layers.5.self_attn.time_decay": "model-00001-of-00004.safetensors",
138
+ "model.layers.5.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
139
+ "model.layers.5.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
140
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
141
+ "model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
142
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
143
+ "model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
144
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
145
+ "model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
146
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
147
+ "model.layers.5.self_attn.gate.weight": "model-00001-of-00004.safetensors",
148
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
149
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
150
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
151
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
152
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
153
+ "model.layers.6.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
154
+ "model.layers.6.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
155
+ "model.layers.6.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
156
+ "model.layers.6.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
157
+ "model.layers.6.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
158
+ "model.layers.6.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
159
+ "model.layers.6.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
160
+ "model.layers.6.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
161
+ "model.layers.6.self_attn.time_decay": "model-00001-of-00004.safetensors",
162
+ "model.layers.6.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
163
+ "model.layers.6.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
164
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
165
+ "model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
166
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
167
+ "model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
168
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
169
+ "model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
170
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
171
+ "model.layers.6.self_attn.gate.weight": "model-00001-of-00004.safetensors",
172
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
173
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
174
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
175
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
176
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
177
+ "model.layers.7.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
178
+ "model.layers.7.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
179
+ "model.layers.7.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
180
+ "model.layers.7.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
181
+ "model.layers.7.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
182
+ "model.layers.7.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
183
+ "model.layers.7.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
184
+ "model.layers.7.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
185
+ "model.layers.7.self_attn.time_decay": "model-00001-of-00004.safetensors",
186
+ "model.layers.7.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
187
+ "model.layers.7.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
188
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
189
+ "model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
190
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
191
+ "model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
192
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
193
+ "model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
194
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
195
+ "model.layers.7.self_attn.gate.weight": "model-00001-of-00004.safetensors",
196
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
197
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
198
+ "model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
199
+ "model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
200
+ "model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
201
+ "model.layers.8.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
202
+ "model.layers.8.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
203
+ "model.layers.8.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
204
+ "model.layers.8.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
205
+ "model.layers.8.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
206
+ "model.layers.8.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
207
+ "model.layers.8.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
208
+ "model.layers.8.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
209
+ "model.layers.8.self_attn.time_decay": "model-00002-of-00004.safetensors",
210
+ "model.layers.8.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
211
+ "model.layers.8.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
212
+ "model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
213
+ "model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
214
+ "model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
215
+ "model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
216
+ "model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
217
+ "model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
218
+ "model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
219
+ "model.layers.8.self_attn.gate.weight": "model-00002-of-00004.safetensors",
220
+ "model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
221
+ "model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
222
+ "model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
223
+ "model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
224
+ "model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
225
+ "model.layers.9.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
226
+ "model.layers.9.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
227
+ "model.layers.9.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
228
+ "model.layers.9.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
229
+ "model.layers.9.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
230
+ "model.layers.9.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
231
+ "model.layers.9.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
232
+ "model.layers.9.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
233
+ "model.layers.9.self_attn.time_decay": "model-00002-of-00004.safetensors",
234
+ "model.layers.9.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
235
+ "model.layers.9.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
236
+ "model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
237
+ "model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
238
+ "model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
239
+ "model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
240
+ "model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
241
+ "model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
242
+ "model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
243
+ "model.layers.9.self_attn.gate.weight": "model-00002-of-00004.safetensors",
244
+ "model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
245
+ "model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
246
+ "model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
247
+ "model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
248
+ "model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
249
+ "model.layers.10.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
250
+ "model.layers.10.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
251
+ "model.layers.10.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
252
+ "model.layers.10.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
253
+ "model.layers.10.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
254
+ "model.layers.10.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
255
+ "model.layers.10.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
256
+ "model.layers.10.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
257
+ "model.layers.10.self_attn.time_decay": "model-00002-of-00004.safetensors",
258
+ "model.layers.10.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
259
+ "model.layers.10.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
260
+ "model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
261
+ "model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
262
+ "model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
263
+ "model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
264
+ "model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
265
+ "model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
266
+ "model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
267
+ "model.layers.10.self_attn.gate.weight": "model-00002-of-00004.safetensors",
268
+ "model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
269
+ "model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
270
+ "model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
271
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
272
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
273
+ "model.layers.11.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
274
+ "model.layers.11.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
275
+ "model.layers.11.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
276
+ "model.layers.11.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
277
+ "model.layers.11.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
278
+ "model.layers.11.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
279
+ "model.layers.11.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
280
+ "model.layers.11.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
281
+ "model.layers.11.self_attn.time_decay": "model-00002-of-00004.safetensors",
282
+ "model.layers.11.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
283
+ "model.layers.11.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
284
+ "model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
285
+ "model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
286
+ "model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
287
+ "model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
288
+ "model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
289
+ "model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
290
+ "model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
291
+ "model.layers.11.self_attn.gate.weight": "model-00002-of-00004.safetensors",
292
+ "model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
293
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
294
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
295
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
296
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
297
+ "model.layers.12.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
298
+ "model.layers.12.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
299
+ "model.layers.12.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
300
+ "model.layers.12.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
301
+ "model.layers.12.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
302
+ "model.layers.12.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
303
+ "model.layers.12.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
304
+ "model.layers.12.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
305
+ "model.layers.12.self_attn.time_decay": "model-00002-of-00004.safetensors",
306
+ "model.layers.12.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
307
+ "model.layers.12.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
308
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
309
+ "model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
310
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
311
+ "model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
312
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
313
+ "model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
314
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
315
+ "model.layers.12.self_attn.gate.weight": "model-00002-of-00004.safetensors",
316
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
317
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
318
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
319
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
320
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
321
+ "model.layers.13.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
322
+ "model.layers.13.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
323
+ "model.layers.13.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
324
+ "model.layers.13.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
325
+ "model.layers.13.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
326
+ "model.layers.13.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
327
+ "model.layers.13.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
328
+ "model.layers.13.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
329
+ "model.layers.13.self_attn.time_decay": "model-00002-of-00004.safetensors",
330
+ "model.layers.13.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
331
+ "model.layers.13.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
332
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
333
+ "model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
334
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
335
+ "model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
336
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
337
+ "model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
338
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
339
+ "model.layers.13.self_attn.gate.weight": "model-00002-of-00004.safetensors",
340
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
341
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
342
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
343
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
344
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
345
+ "model.layers.14.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
346
+ "model.layers.14.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
347
+ "model.layers.14.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
348
+ "model.layers.14.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
349
+ "model.layers.14.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
350
+ "model.layers.14.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
351
+ "model.layers.14.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
352
+ "model.layers.14.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
353
+ "model.layers.14.self_attn.time_decay": "model-00002-of-00004.safetensors",
354
+ "model.layers.14.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
355
+ "model.layers.14.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
356
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
357
+ "model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
358
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
359
+ "model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
360
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
361
+ "model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
362
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
363
+ "model.layers.14.self_attn.gate.weight": "model-00002-of-00004.safetensors",
364
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
365
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
366
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
367
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
368
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
369
+ "model.layers.15.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
370
+ "model.layers.15.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
371
+ "model.layers.15.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
372
+ "model.layers.15.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
373
+ "model.layers.15.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
374
+ "model.layers.15.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
375
+ "model.layers.15.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
376
+ "model.layers.15.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
377
+ "model.layers.15.self_attn.time_decay": "model-00002-of-00004.safetensors",
378
+ "model.layers.15.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
379
+ "model.layers.15.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
380
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
381
+ "model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
382
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
383
+ "model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
384
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
385
+ "model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
386
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
387
+ "model.layers.15.self_attn.gate.weight": "model-00002-of-00004.safetensors",
388
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
389
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
390
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
391
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
392
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
393
+ "model.layers.16.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
394
+ "model.layers.16.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
395
+ "model.layers.16.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
396
+ "model.layers.16.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
397
+ "model.layers.16.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
398
+ "model.layers.16.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
399
+ "model.layers.16.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
400
+ "model.layers.16.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
401
+ "model.layers.16.self_attn.time_decay": "model-00002-of-00004.safetensors",
402
+ "model.layers.16.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
403
+ "model.layers.16.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
404
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
405
+ "model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
406
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
407
+ "model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
408
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
409
+ "model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
410
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
411
+ "model.layers.16.self_attn.gate.weight": "model-00002-of-00004.safetensors",
412
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
413
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
414
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
415
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
416
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
417
+ "model.layers.17.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
418
+ "model.layers.17.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
419
+ "model.layers.17.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
420
+ "model.layers.17.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
421
+ "model.layers.17.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
422
+ "model.layers.17.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
423
+ "model.layers.17.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
424
+ "model.layers.17.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
425
+ "model.layers.17.self_attn.time_decay": "model-00002-of-00004.safetensors",
426
+ "model.layers.17.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
427
+ "model.layers.17.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
428
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
429
+ "model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
430
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
431
+ "model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
432
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
433
+ "model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
434
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
435
+ "model.layers.17.self_attn.gate.weight": "model-00002-of-00004.safetensors",
436
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
437
+ "model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
438
+ "model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
439
+ "model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
440
+ "model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
441
+ "model.layers.18.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
442
+ "model.layers.18.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
443
+ "model.layers.18.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
444
+ "model.layers.18.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
445
+ "model.layers.18.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
446
+ "model.layers.18.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
447
+ "model.layers.18.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
448
+ "model.layers.18.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
449
+ "model.layers.18.self_attn.time_decay": "model-00003-of-00004.safetensors",
450
+ "model.layers.18.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
451
+ "model.layers.18.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
452
+ "model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
453
+ "model.layers.18.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
454
+ "model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
455
+ "model.layers.18.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
456
+ "model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
457
+ "model.layers.18.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
458
+ "model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
459
+ "model.layers.18.self_attn.gate.weight": "model-00003-of-00004.safetensors",
460
+ "model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
461
+ "model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
462
+ "model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
463
+ "model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
464
+ "model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
465
+ "model.layers.19.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
466
+ "model.layers.19.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
467
+ "model.layers.19.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
468
+ "model.layers.19.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
469
+ "model.layers.19.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
470
+ "model.layers.19.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
471
+ "model.layers.19.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
472
+ "model.layers.19.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
473
+ "model.layers.19.self_attn.time_decay": "model-00003-of-00004.safetensors",
474
+ "model.layers.19.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
475
+ "model.layers.19.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
476
+ "model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
477
+ "model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
478
+ "model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
479
+ "model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
480
+ "model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
481
+ "model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
482
+ "model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
483
+ "model.layers.19.self_attn.gate.weight": "model-00003-of-00004.safetensors",
484
+ "model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
485
+ "model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
486
+ "model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
487
+ "model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
488
+ "model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
489
+ "model.layers.20.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
490
+ "model.layers.20.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
491
+ "model.layers.20.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
492
+ "model.layers.20.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
493
+ "model.layers.20.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
494
+ "model.layers.20.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
495
+ "model.layers.20.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
496
+ "model.layers.20.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
497
+ "model.layers.20.self_attn.time_decay": "model-00003-of-00004.safetensors",
498
+ "model.layers.20.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
499
+ "model.layers.20.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
500
+ "model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
501
+ "model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
502
+ "model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
503
+ "model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
504
+ "model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
505
+ "model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
506
+ "model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
507
+ "model.layers.20.self_attn.gate.weight": "model-00003-of-00004.safetensors",
508
+ "model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
509
+ "model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
510
+ "model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
511
+ "model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
512
+ "model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
513
+ "model.layers.21.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
514
+ "model.layers.21.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
515
+ "model.layers.21.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
516
+ "model.layers.21.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
517
+ "model.layers.21.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
518
+ "model.layers.21.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
519
+ "model.layers.21.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
520
+ "model.layers.21.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
521
+ "model.layers.21.self_attn.time_decay": "model-00003-of-00004.safetensors",
522
+ "model.layers.21.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
523
+ "model.layers.21.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
524
+ "model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
525
+ "model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
526
+ "model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
527
+ "model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
528
+ "model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
529
+ "model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
530
+ "model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
531
+ "model.layers.21.self_attn.gate.weight": "model-00003-of-00004.safetensors",
532
+ "model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
533
+ "model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
534
+ "model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
535
+ "model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
536
+ "model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
537
+ "model.layers.22.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
538
+ "model.layers.22.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
539
+ "model.layers.22.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
540
+ "model.layers.22.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
541
+ "model.layers.22.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
542
+ "model.layers.22.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
543
+ "model.layers.22.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
544
+ "model.layers.22.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
545
+ "model.layers.22.self_attn.time_decay": "model-00003-of-00004.safetensors",
546
+ "model.layers.22.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
547
+ "model.layers.22.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
548
+ "model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
549
+ "model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
550
+ "model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
551
+ "model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
552
+ "model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
553
+ "model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
554
+ "model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
555
+ "model.layers.22.self_attn.gate.weight": "model-00003-of-00004.safetensors",
556
+ "model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
557
+ "model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
558
+ "model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
559
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
560
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
561
+ "model.layers.23.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
562
+ "model.layers.23.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
563
+ "model.layers.23.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
564
+ "model.layers.23.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
565
+ "model.layers.23.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
566
+ "model.layers.23.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
567
+ "model.layers.23.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
568
+ "model.layers.23.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
569
+ "model.layers.23.self_attn.time_decay": "model-00003-of-00004.safetensors",
570
+ "model.layers.23.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
571
+ "model.layers.23.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
572
+ "model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
573
+ "model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
574
+ "model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
575
+ "model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
576
+ "model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
577
+ "model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
578
+ "model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
579
+ "model.layers.23.self_attn.gate.weight": "model-00003-of-00004.safetensors",
580
+ "model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
581
+ "model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
582
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
583
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
584
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
585
+ "model.layers.24.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
586
+ "model.layers.24.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
587
+ "model.layers.24.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
588
+ "model.layers.24.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
589
+ "model.layers.24.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
590
+ "model.layers.24.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
591
+ "model.layers.24.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
592
+ "model.layers.24.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
593
+ "model.layers.24.self_attn.time_decay": "model-00003-of-00004.safetensors",
594
+ "model.layers.24.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
595
+ "model.layers.24.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
596
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
597
+ "model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
598
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
599
+ "model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
600
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
601
+ "model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
602
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
603
+ "model.layers.24.self_attn.gate.weight": "model-00003-of-00004.safetensors",
604
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
605
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
606
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
607
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
608
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
609
+ "model.layers.25.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
610
+ "model.layers.25.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
611
+ "model.layers.25.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
612
+ "model.layers.25.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
613
+ "model.layers.25.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
614
+ "model.layers.25.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
615
+ "model.layers.25.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
616
+ "model.layers.25.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
617
+ "model.layers.25.self_attn.time_decay": "model-00003-of-00004.safetensors",
618
+ "model.layers.25.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
619
+ "model.layers.25.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
620
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
621
+ "model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
622
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
623
+ "model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
624
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
625
+ "model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
626
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
627
+ "model.layers.25.self_attn.gate.weight": "model-00003-of-00004.safetensors",
628
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
629
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
630
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
631
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
632
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
633
+ "model.layers.26.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
634
+ "model.layers.26.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
635
+ "model.layers.26.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
636
+ "model.layers.26.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
637
+ "model.layers.26.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
638
+ "model.layers.26.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
639
+ "model.layers.26.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
640
+ "model.layers.26.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
641
+ "model.layers.26.self_attn.time_decay": "model-00003-of-00004.safetensors",
642
+ "model.layers.26.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
643
+ "model.layers.26.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
644
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
645
+ "model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
646
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
647
+ "model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
648
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
649
+ "model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
650
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
651
+ "model.layers.26.self_attn.gate.weight": "model-00003-of-00004.safetensors",
652
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
653
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
654
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
655
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
656
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
657
+ "model.layers.27.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
658
+ "model.layers.27.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
659
+ "model.layers.27.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
660
+ "model.layers.27.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
661
+ "model.layers.27.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
662
+ "model.layers.27.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
663
+ "model.layers.27.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
664
+ "model.layers.27.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
665
+ "model.layers.27.self_attn.time_decay": "model-00003-of-00004.safetensors",
666
+ "model.layers.27.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
667
+ "model.layers.27.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
668
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
669
+ "model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
670
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
671
+ "model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
672
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
673
+ "model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
674
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
675
+ "model.layers.27.self_attn.gate.weight": "model-00003-of-00004.safetensors",
676
+ "model.layers.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
677
+ "model.layers.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
678
+ "model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
679
+ "model.norm.weight": "model-00004-of-00004.safetensors",
680
+ "lm_head.weight": "model-00004-of-00004.safetensors"
681
+ }
682
+ }
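The weight map above routes each layer's RWKV-6 parameters (time_maa_*, time_decay*, gate) together with the usual Qwen2 projections across four safetensors shards. Because the model classes are shipped as custom code in this repo (modeling_rwkv6qwen2.py, added below), loading the checkpoint goes through trust_remote_code. A minimal loading sketch, assuming the files live under the RWKV/RWKV6Qwen2-7B repo id named in _CHECKPOINT_FOR_DOC below; adjust the repo id to wherever these files actually sit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "RWKV/RWKV6Qwen2-7B"  # assumed location of these files
    tok = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(
        repo,
        trust_remote_code=True,   # resolves RWKV6Qwen2ForCausalLM through the repo's custom code
        torch_dtype="auto",       # keep the dtype stored with the checkpoint
    )
    out = model.generate(**tok("Hello", return_tensors="pt"), max_new_tokens=16)
    print(tok.decode(out[0], skip_special_tokens=True))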
modeling_rwkv6qwen2.py ADDED
@@ -0,0 +1,1336 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """PyTorch RWKV6Qwen2 model."""
21
+
22
+ import math
23
+ import inspect
24
+ from typing import List, Optional, Tuple, Union, Dict, Any
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ import torch.nn.functional as F
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.cache_utils import Cache, StaticCache, DynamicCache
33
+ from transformers.generation import GenerationMixin
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPast,
36
+ CausalLMOutputWithPast,
37
+ QuestionAnsweringModelOutput,
38
+ SequenceClassifierOutputWithPast,
39
+ TokenClassifierOutput,
40
+ )
41
+ from transformers.modeling_utils import PreTrainedModel
42
+ from transformers.utils import (
43
+ add_code_sample_docstrings,
44
+ add_start_docstrings,
45
+ add_start_docstrings_to_model_forward,
46
+ is_flash_attn_2_available,
47
+ is_flash_attn_greater_or_equal_2_10,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from .configuration_rwkv6qwen2 import RWKV6Qwen2Config
52
+
53
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv
54
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+
59
+ _CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
60
+ _CONFIG_FOR_DOC = "RWKV6Qwen2Config"
61
+
62
+ class RWKV6State(Cache):
63
+ def __init__(self) -> None:
64
+ super().__init__()
65
+ self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
66
+ self.layer_kv_states: List[torch.Tensor] = []
67
+ self.layer_shift_states: List[torch.Tensor] = []
68
+
69
+ def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
70
+ """
71
+ Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
72
+ sequence length.
73
+ """
74
+ if layer_idx < len(self):
75
+ return (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
76
+ else:
77
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
78
+
79
+ def __iter__(self):
80
+ """
81
+ Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
82
+ keys and values
83
+ """
84
+ for layer_idx in range(len(self)):
85
+ yield (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
86
+
87
+ def __len__(self):
88
+ """
89
+ Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
90
+ to the number of layers in the model.
91
+ """
92
+ return len(self.layer_kv_states)
93
+
94
+ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
95
+ """Given the sequence length of the new inputs, returns the usable length of the cache."""
96
+ # Linear Attention variants do not have a maximum length
97
+ return new_seq_length
98
+
99
+ def reorder_cache(self, beam_idx: torch.LongTensor):
100
+ """Reorders the cache for beam search, given the selected beam indices."""
101
+ raise NotImplementedError('Cannot reorder Linear Attention state')
102
+
103
+ def get_seq_length(self, layer_idx: int = 0) -> int:
104
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
105
+ return self._seen_tokens
106
+
107
+ def get_max_cache_shape(self) -> Optional[int]:
108
+ """Returns the maximum sequence length of the cache object. RWKV6State does not have a maximum length."""
109
+ return None
110
+
111
+ def get_max_length(self) -> Optional[int]:
112
+ """
113
+ Returns the maximum sequence length of the cached states. RWKV6State does not have a maximum length.
114
+ """
115
+ return None
116
+
117
+ # def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
118
+ # """Converts the `DynamicCache` instance into its equivalent in the legacy cache format. Used for
119
+ # backward compatibility."""
120
+ # legacy_cache = ()
121
+ # for layer_idx in range(len(self)):
122
+ # legacy_cache += ((self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]),)
123
+ # return legacy_cache
124
+
125
+ # @classmethod
126
+ # #@deprecate_kwarg("num_hidden_layers", version="4.47.0")
127
+ # def from_legacy_cache(
128
+ # cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, num_hidden_layers: int | None = None
129
+ # ) -> "RWKV6State":
130
+ # """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
131
+ # backward compatibility."""
132
+ # cache = cls()
133
+ # if past_key_values is not None:
134
+ # for layer_idx in range(len(past_key_values)):
135
+ # layer_kv_state, layer_shift_state = past_key_values[layer_idx]
136
+ # cache.update(layer_kv_state, layer_shift_state, layer_idx)
137
+ # return cache
138
+
139
+ def crop(self, max_length: int):
140
+ # can't implement this for linear attention variants
141
+ return
142
+
143
+ @torch.no_grad
144
+ def update(
145
+ self,
146
+ kv_state: torch.Tensor,
147
+ shift_state: torch.Tensor,
148
+ token_count: int,
149
+ layer_idx: int,
150
+ cache_kwargs: Optional[Dict[str, Any]] = None,
151
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
152
+ # Update the number of seen tokens
153
+ if layer_idx == 0:
154
+ self._seen_tokens += token_count
155
+
156
+ # Update the cache
157
+ # There may be skipped layers, fill them with empty lists
158
+ for _ in range(len(self.layer_kv_states), layer_idx + 1):
159
+ self.layer_kv_states.append(torch.zeros_like(kv_state).requires_grad_(False))
160
+ self.layer_shift_states.append(torch.zeros_like(shift_state).requires_grad_(False))
161
+ self.layer_kv_states[layer_idx].copy_(kv_state)
162
+ self.layer_shift_states[layer_idx].copy_(shift_state)
163
+
164
+ return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
165
+
166
+ # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
167
+ # def batch_split(
168
+ # self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
169
+ # ) -> List["DynamicCache"]:
170
+ # """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
171
+ # `_split_model_inputs()` in `generation.utils`"""
172
+ # out = []
173
+ # for i in range(0, full_batch_size, split_size):
174
+ # current_split = DynamicCache()
175
+ # current_split._seen_tokens = self._seen_tokens
176
+ # current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
177
+ # current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
178
+ # out.append(current_split)
179
+ # return out
180
+
181
+ # @classmethod
182
+ # @deprecate_kwarg("num_hidden_layers", version="4.47.0")
183
+ # def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
184
+ # """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
185
+ # `generation.utils`"""
186
+ # cache = cls()
187
+ # for idx in range(len(splits[0])):
188
+ # key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
189
+ # value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
190
+ # if key_cache != []:
191
+ # layer_keys = torch.cat(key_cache, dim=0)
192
+ # layer_values = torch.cat(value_cache, dim=0)
193
+ # cache.update(layer_keys, layer_values, idx)
194
+ # return cache
195
+
196
+ # def batch_repeat_interleave(self, repeats: int):
197
+ # """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
198
+ # for layer_idx in range(len(self)):
199
+ # self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
200
+ # self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
201
+
202
+ # def batch_select_indices(self, indices: torch.Tensor):
203
+ # """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
204
+ # for layer_idx in range(len(self)):
205
+ # self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
206
+ # self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
207
+
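Note that RWKV6State is a fixed-size recurrent state rather than a growing key/value cache: per layer it keeps one kv_state matrix and a one-token shift_state, so memory stays constant no matter how many tokens have been processed (which is also why crop() is a no-op and reorder_cache() is unsupported). A hedged usage sketch, assuming the usual transformers calling convention; model, prompt_ids and next_ids are placeholders:

    state = RWKV6State()                                   # empty recurrent state
    out = model(input_ids=prompt_ids, past_key_values=state, use_cache=True)
    # subsequent chunks reuse the same state object; only the new tokens are fed in
    out = model(input_ids=next_ids, past_key_values=state, use_cache=True)
    print(state.get_seq_length())                          # tokens absorbed so far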
208
+ try:
209
+ #from fla.ops.gla.chunk import chunk_gla
210
+ from fla.ops.gla.fused_recurrent import fused_recurrent_gla
211
+ except ImportError:
212
+ print("Required module is not installed. Please install it using the following commands:")
213
+ print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
214
+ print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
215
+ print("pip install triton>=2.2.0")
216
+
217
+ class Qwen2RotaryEmbedding(nn.Module):
218
+ def __init__(self, config: RWKV6Qwen2Config, device=None):
219
+ super().__init__()
220
+ # BC: "rope_type" was originally "type"
221
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
222
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
223
+ else:
224
+ self.rope_type = "default"
225
+ self.max_seq_len_cached = config.max_position_embeddings
226
+ self.original_max_seq_len = config.max_position_embeddings
227
+
228
+ self.config = config
229
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
230
+
231
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
232
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
233
+ self.original_inv_freq = self.inv_freq
234
+
235
+ def _dynamic_frequency_update(self, position_ids, device):
236
+ """
237
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
238
+ 1 - growing beyond the cached sequence length (allow scaling)
239
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
240
+ """
241
+ seq_len = torch.max(position_ids) + 1
242
+ if seq_len > self.max_seq_len_cached: # growth
243
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
244
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
245
+ self.max_seq_len_cached = seq_len
246
+
247
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
248
+ # This .to() is needed if the model has been moved to a device after being initialized (because
249
+ # the buffer is automatically moved, but not the original copy)
250
+ self.original_inv_freq = self.original_inv_freq.to(device)
251
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
252
+ self.max_seq_len_cached = self.original_max_seq_len
253
+
254
+ @torch.no_grad()
255
+ def forward(self, x, position_ids):
256
+ if "dynamic" in self.rope_type:
257
+ self._dynamic_frequency_update(position_ids, device=x.device)
258
+
259
+ # Core RoPE block
260
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
261
+ position_ids_expanded = position_ids[:, None, :].float()
262
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
263
+ device_type = x.device.type
264
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
265
+ with torch.autocast(device_type=device_type, enabled=False):
266
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
267
+ emb = torch.cat((freqs, freqs), dim=-1)
268
+ cos = emb.cos()
269
+ sin = emb.sin()
270
+
271
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
272
+ cos = cos * self.attention_scaling
273
+ sin = sin * self.attention_scaling
274
+
275
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
276
+
277
+ def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
278
+ #inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
279
+
280
+ angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
281
+ angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
282
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
283
+ emb = torch.cat((angles, angles), dim=-1)
284
+ return torch.stack([emb.cos(), emb.sin()], dim=0)
285
+ #return torch.polar(torch.ones_like(angles), angles)
286
+
287
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
288
+ def rotate_half(x):
289
+ """Rotates half the hidden dims of the input."""
290
+ x1 = x[..., : x.shape[-1] // 2]
291
+ x2 = x[..., x.shape[-1] // 2 :]
292
+ return torch.cat((-x2, x1), dim=-1)
293
+
294
+ # # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
295
+ # def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim:int=1):
296
+ # B, L = q.size(0), q.size(-2)
297
+ # cos = cos[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
298
+ # sin = sin[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
299
+ # q_embed = (q * cos) + (rotate_half(q) * sin)
300
+ # k_embed = (k * cos) + (rotate_half(k) * sin)
301
+ # return q_embed, k_embed
302
+
303
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
304
+ """Applies Rotary Position Embedding to the query and key tensors.
305
+
306
+ Args:
307
+ q (`torch.Tensor`): The query tensor.
308
+ k (`torch.Tensor`): The key tensor.
309
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
310
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
311
+ position_ids (`torch.Tensor`, *optional*):
312
+ Deprecated and unused.
313
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
314
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
315
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
316
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
317
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
318
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
319
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
320
+ Returns:
321
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
322
+ """
323
+ cos = cos.unsqueeze(unsqueeze_dim)
324
+ sin = sin.unsqueeze(unsqueeze_dim)
325
+ q_embed = (q * cos) + (rotate_half(q) * sin)
326
+ k_embed = (k * cos) + (rotate_half(k) * sin)
327
+ return q_embed, k_embed
328
+
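For reference, with the [batch, heads, seq_len, head_dim] layout used later in RWKV6Attention, the helper is applied with unsqueeze_dim=1 so that cos/sin of shape [batch, seq_len, head_dim] broadcast over the head axis (an illustrative sketch; q, k, x, position_ids and rotary_emb are placeholders):

    # q, k: [batch, n_heads, seq_len, head_dim]; cos, sin: [batch, seq_len, head_dim]
    cos, sin = rotary_emb(x, position_ids)     # e.g. the Qwen2RotaryEmbedding above
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)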
329
+ def ortho_init(x, scale):
330
+ with torch.no_grad():
331
+ shape = x.shape
332
+ if len(shape) == 2:
333
+ gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
334
+ #nn.init.orthogonal_(x, gain=gain * scale)
335
+ x.copy_(nn.init.orthogonal_(torch.empty_like(x, dtype=torch.float32), gain=gain * scale))
336
+ elif len(shape) == 3:
337
+ gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
338
+ for i in range(shape[0]):
339
+ #nn.init.orthogonal_(x[i], gain=gain * scale)
340
+ x[i].copy_(nn.init.orthogonal_(torch.empty_like(x[i], dtype=torch.float32), gain=gain * scale))
341
+ else:
342
+ assert False
343
+ return x
344
+
345
+ class RWKV6Attention(nn.Module):
346
+ def __init__(self, config, layer_idx: Optional[int] = None):
347
+ super().__init__()
348
+ self.config = config
349
+ self.layer_idx = layer_idx
350
+
351
+ if layer_idx is None:
352
+ logger.warning_once(
353
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
354
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
355
+ "when creating this class."
356
+ )
357
+
358
+ self.hidden_size = config.hidden_size
359
+ self.num_heads = config.num_attention_heads
360
+ self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
361
+ self.num_key_value_heads = config.num_key_value_heads
362
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
363
+ self.attention_dropout = config.attention_dropout
364
+
365
+ n_layer = self.config.num_hidden_layers
366
+ n_embd = self.hidden_size
367
+ dim_att = self.num_heads * self.head_dim
368
+ layer_id = self.layer_idx
369
+
370
+ if self.hidden_size % self.num_heads != 0:
371
+ raise ValueError(
372
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
373
+ f" and `num_heads`: {self.num_heads})."
374
+ )
375
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
376
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
377
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
378
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))
379
+
380
+ calc_lora_rank = lambda exponent, multiplier: max(1, round(self.hidden_size ** exponent * multiplier / 32)) * 32
381
+
382
+ if config.gate_rank_type == 1:
383
+ self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
384
+ elif config.gate_rank_type == 2:
385
+ lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
386
+ self.g1 = nn.Parameter(torch.empty(n_embd, lora_rank_gate))
387
+ self.g2 = nn.Parameter(torch.empty(lora_rank_gate, n_embd))
388
+
389
+ if config.groupnorm_att:
390
+ self.ln_x = nn.GroupNorm(self.num_heads, dim_att, eps=self.head_dim * 1e-5)
391
+
392
+ with torch.no_grad():
393
+ if config.gate_rank_type == 1:
394
+ self.gate.weight.zero_()
395
+ elif config.gate_rank_type == 2:
396
+ self.g1.zero_()
397
+ ortho_init(self.g2, 0.1)
398
+
399
+ ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
400
+ ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
401
+
402
+ if self.config.use_tokenshift:
403
+ ddd = torch.ones(1, 1, n_embd)
404
+ for i in range(n_embd):
405
+ ddd[0, 0, i] = i / n_embd
406
+
407
+ ddd = torch.zeros(1, 1, n_embd)
408
+ self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
409
+ self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
410
+ self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
411
+ self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
412
+ self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
413
+ self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
414
+
415
+ lora_rank_tokenshift = config.lora_rank_tokenshift or (32 if n_embd < 4096 else 64)
416
+
417
+ self.time_maa_w2 = nn.Parameter(torch.zeros(5, lora_rank_tokenshift, n_embd).uniform_(-0.01, 0.01))
418
+ self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_tokenshift*self.time_maa_w2.size(0)))
419
+
420
+ lora_rank_decay = config.lora_rank_decay or (64 if n_embd < 4096 else 128)
421
+
422
+ # RWKV-6
423
+ decay_speed = torch.ones(dim_att)
424
+ for n in range(dim_att):
425
+ decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
426
+ self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
427
+ self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_decay))
428
+ self.time_decay_w2 = nn.Parameter(torch.zeros(lora_rank_decay, dim_att).uniform_(-0.01, 0.01))
429
+
430
+ def forward(
431
+ self,
432
+ hidden_states: torch.Tensor,
433
+ attention_mask: Optional[torch.Tensor] = None,
434
+ position_ids: Optional[torch.LongTensor] = None,
435
+ past_key_values: Optional[RWKV6State] = None,
436
+ output_attentions: bool = False,
437
+ use_cache: bool = False,
438
+ cache_position: Optional[torch.LongTensor] = None,
439
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
440
+ ):
441
+ output_shift_state = hidden_states[:, -1:].detach().clone()
442
+
443
+ bsz, q_len, hidden_dim = hidden_states.size()
444
+ H = self.num_heads
445
+
446
+ x = hidden_states
447
+
448
+ if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
449
+ input_kv_state, input_shift_state = past_key_values[self.layer_idx]
450
+ xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
451
+ else:
452
+ input_kv_state = None
453
+ xprev = F.pad(x, (0, 0, 1, -1))
454
+
455
+ if self.config.use_tokenshift:
456
+ dxprev = xprev - x
457
+
458
+ xxx = x + dxprev * self.time_maa_x
459
+ xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
460
+ xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)
461
+
462
+ mr, mk, mv, mw, mg = xxx.unbind(dim=0)
463
+ xr = x + dxprev * (self.time_maa_r + mr)
464
+ xk = x + dxprev * (self.time_maa_k + mk)
465
+ xv = x + dxprev * (self.time_maa_v + mv)
466
+ xw = x + dxprev * (self.time_maa_w + mw)
467
+ xg = x + dxprev * (self.time_maa_g + mg)
468
+ else:
469
+ xr = xk = xv = xw = xg = x
470
+
471
+ query_states = self.q_proj(xr)
472
+ key_states = self.k_proj(xk)
473
+ value_states = self.v_proj(xv)
474
+ decay_states = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(query_states.dtype)
475
+ if self.config.gate_rank_type == 1:
476
+ gate_states = torch.sigmoid(self.gate(xg))
477
+ elif self.config.gate_rank_type == 2:
478
+ gate_states = torch.sigmoid(xg @ self.g1) @ self.g2
479
+
480
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
481
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
482
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
483
+ decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
484
+
485
+ if position_embeddings is not None:
486
+ cos, sin = position_embeddings
487
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
488
+
489
+ # repeat k/v heads if n_kv_heads < n_heads
490
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
491
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
492
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
493
+
494
+ decay_states_log = -decay_states.float().exp()
495
+ decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
496
+ if self.config.balance_state:
497
+ key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)
498
+
499
+ # dealing with left-padding
500
+ if attention_mask is not None:
501
+ value_states = value_states * attention_mask[:, None, -value_states.shape[-2]:, None]
502
+
503
+ query_states = query_states.to(value_states.dtype)
504
+ key_states = key_states.to(value_states.dtype)
505
+
506
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
508
+ # so the input hidden states may get silently cast to float32. Hence, we need to
509
+ # cast them back to the expected lower-precision dtype so everything works as intended.
509
+ input_dtype = query_states.dtype
510
+ if input_dtype == torch.float32:
511
+ if torch.is_autocast_enabled():
512
+ target_dtype = torch.get_autocast_gpu_dtype()
513
+ # Handle the case where the model is quantized
514
+ elif hasattr(self.config, "_pre_quantization_dtype"):
515
+ target_dtype = self.config._pre_quantization_dtype
516
+ else:
517
+ target_dtype = self.q_proj.weight.dtype
518
+
519
+ logger.warning_once(
520
+ f"The input hidden states seem to have been silently cast to float32; this might be because you have"
521
+ f" upcast embedding or layer norm layers to float32. We will cast the input back to"
522
+ f" {target_dtype}."
523
+ )
524
+
525
+ query_states = query_states.to(target_dtype)
526
+ key_states = key_states.to(target_dtype)
527
+ value_states = value_states.to(target_dtype)
528
+
529
+ attn_weights = torch.empty(0, device=x.device)
530
+
531
+ scale = query_states.shape[-1] ** -0.5
532
+ output_final_state = not self.training and use_cache and past_key_values is not None
533
+ #attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
534
+ #attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
535
+ attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)
536
+
537
+ if output_final_state:
538
+ past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
539
+
540
+ attn_output = attn_output.transpose(1, 2).contiguous()
541
+ attn_output = attn_output.view(bsz, q_len, -1)
542
+ if self.config.groupnorm_att:
543
+ attn_output = self.ln_x(attn_output.view(bsz * q_len, -1)).view(bsz, q_len, -1)
544
+ if self.config.gate_rank_type != 0:
545
+ attn_output = attn_output * gate_states
546
+ attn_output = self.o_proj(attn_output)
547
+
548
+ return attn_output, attn_weights
549
+
550
+ class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
551
+ def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
552
+ nn.Module.__init__(self)
553
+ self.hidden_size = config.hidden_size
554
+
555
+ self.self_attn = RWKV6Attention(config, layer_idx) #QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
556
+
557
+ self.mlp = Qwen2MLP(config)
558
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
559
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
560
+
561
+ def forward(
562
+ self,
563
+ hidden_states: torch.Tensor,
564
+ attention_mask: Optional[torch.Tensor] = None,
565
+ position_ids: Optional[torch.LongTensor] = None,
566
+ past_key_values: Optional[Cache] = None,
567
+ output_attentions: Optional[bool] = False,
568
+ use_cache: Optional[bool] = False,
569
+ cache_position: Optional[torch.LongTensor] = None,
570
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
571
+ **kwargs,
572
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
573
+ residual = hidden_states
574
+
575
+ hidden_states = self.input_layernorm(hidden_states)
576
+
577
+ # Self Attention
578
+ hidden_states, self_attn_weights = self.self_attn(
579
+ hidden_states=hidden_states,
580
+ attention_mask=attention_mask,
581
+ position_ids=position_ids,
582
+ past_key_values=past_key_values,
583
+ output_attentions=output_attentions,
584
+ use_cache=use_cache,
585
+ cache_position=cache_position,
586
+ position_embeddings=position_embeddings,
587
+ **kwargs,
588
+ )
589
+ hidden_states = residual + hidden_states
590
+
591
+ # Fully Connected
592
+ residual = hidden_states
593
+ hidden_states = self.post_attention_layernorm(hidden_states)
594
+ hidden_states = self.mlp(hidden_states)
595
+ hidden_states = residual + hidden_states
596
+
597
+ outputs = (hidden_states,)
598
+ if output_attentions:
599
+ outputs += (self_attn_weights,)
600
+
601
+ return outputs
602
+
603
+ RWKV6QWEN2_START_DOCSTRING = r"""
604
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
605
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
606
+ etc.)
607
+
608
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
609
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
610
+ and behavior.
611
+
612
+ Parameters:
613
+ config ([`RWKV6Qwen2Config`]):
614
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
615
+ load the weights associated with the model, only the configuration. Check out the
616
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
617
+ """
618
+
619
+
620
+ @add_start_docstrings(
621
+ "The bare RWKV6Qwen2 Model outputting raw hidden-states without any specific head on top.",
622
+ RWKV6QWEN2_START_DOCSTRING,
623
+ )
624
+ class RWKV6Qwen2PreTrainedModel(PreTrainedModel):
625
+ config_class = RWKV6Qwen2Config
626
+ base_model_prefix = "model"
627
+ supports_gradient_checkpointing = True
628
+ _no_split_modules = ["RWKV6Qwen2DecoderLayer"]
629
+ _skip_keys_device_placement = "past_key_values"
630
+ _supports_flash_attn_2 = True
631
+ _supports_sdpa = True
632
+ _supports_cache_class = True
633
+ _supports_quantized_cache = True
634
+ _supports_static_cache = True
635
+
636
+ def _init_weights(self, module):
637
+ std = self.config.initializer_range
638
+ if isinstance(module, nn.Linear):
639
+ module.weight.data.normal_(mean=0.0, std=std)
640
+ if module.bias is not None:
641
+ module.bias.data.zero_()
642
+ elif isinstance(module, nn.Embedding):
643
+ module.weight.data.normal_(mean=0.0, std=std)
644
+ if module.padding_idx is not None:
645
+ module.weight.data[module.padding_idx].zero_()
646
+
647
+
648
+ RWKV6QWEN2_INPUTS_DOCSTRING = r"""
649
+ Args:
650
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
651
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
652
+ it.
653
+
654
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
655
+ [`PreTrainedTokenizer.__call__`] for details.
656
+
657
+ [What are input IDs?](../glossary#input-ids)
658
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
659
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
660
+
661
+ - 1 for tokens that are **not masked**,
662
+ - 0 for tokens that are **masked**.
663
+
664
+ [What are attention masks?](../glossary#attention-mask)
665
+
666
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
667
+ [`PreTrainedTokenizer.__call__`] for details.
668
+
669
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
670
+ `past_key_values`).
671
+
672
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
673
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
674
+ information on the default strategy.
675
+
676
+ - 1 indicates the head is **not masked**,
677
+ - 0 indicates the head is **masked**.
678
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
679
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
680
+ config.n_positions - 1]`.
681
+
682
+ [What are position IDs?](../glossary#position-ids)
683
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
684
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
685
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
686
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
687
+
688
+ Two formats are allowed:
689
+ - a [`~cache_utils.Cache`] instance, see our
690
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
691
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
692
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
693
+ cache format.
694
+
695
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
696
+ legacy cache format will be returned.
697
+
698
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
699
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
700
+ of shape `(batch_size, sequence_length)`.
701
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
702
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
703
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
704
+ model's internal embedding lookup matrix.
705
+ use_cache (`bool`, *optional*):
706
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
707
+ `past_key_values`).
708
+ output_attentions (`bool`, *optional*):
709
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
710
+ tensors for more detail.
711
+ output_hidden_states (`bool`, *optional*):
712
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
713
+ more detail.
714
+ return_dict (`bool`, *optional*):
715
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
716
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
717
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
718
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
719
+ the complete sequence length.
720
+ """
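+ # Minimal usage sketch for the inputs documented above (illustrative only): `model` is assumed
+ # to be an already loaded `RWKV6Qwen2ForCausalLM` in eval mode and `tokenizer` its tokenizer.
+ def _example_incremental_decoding(model, tokenizer, prompt: str):
+     batch = tokenizer(prompt, return_tensors="pt")
+     # prefill: run the whole prompt once; the returned cache holds the recurrent state
+     out = model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, use_cache=True)
+     next_token = out.logits[:, -1:].argmax(dim=-1)
+     # decode: feed only the newly chosen token together with the returned cache
+     out = model(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
+     return out.logits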
721
+
722
+ @add_start_docstrings(
723
+ "The bare RWKV6Qwen2 Model outputting raw hidden-states without any specific head on top.",
724
+ RWKV6QWEN2_START_DOCSTRING,
725
+ )
726
+ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
727
+ """
728
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`RWKV6Qwen2DecoderLayer`]
729
+
730
+ Args:
731
+ config: RWKV6Qwen2Config
732
+ """
733
+
734
+ def __init__(self, config: RWKV6Qwen2Config):
735
+ super().__init__(config)
736
+ self.padding_idx = config.pad_token_id
737
+ self.vocab_size = config.vocab_size
738
+
739
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
740
+ self.layers = nn.ModuleList(
741
+ [RWKV6Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
742
+ )
743
+ self._attn_implementation = config._attn_implementation
744
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
745
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
746
+
747
+ self.gradient_checkpointing = False
748
+ # Initialize weights and apply final processing
749
+ self.post_init()
750
+
751
+ def get_input_embeddings(self):
752
+ return self.embed_tokens
753
+
754
+ def set_input_embeddings(self, value):
755
+ self.embed_tokens = value
756
+
757
+ @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
758
+ def forward(
759
+ self,
760
+ input_ids: torch.LongTensor = None,
761
+ attention_mask: Optional[torch.Tensor] = None,
762
+ position_ids: Optional[torch.LongTensor] = None,
763
+ past_key_values: Optional[Cache] = None,
764
+ inputs_embeds: Optional[torch.FloatTensor] = None,
765
+ use_cache: Optional[bool] = None,
766
+ output_attentions: Optional[bool] = None,
767
+ output_hidden_states: Optional[bool] = None,
768
+ return_dict: Optional[bool] = None,
769
+ cache_position: Optional[torch.LongTensor] = None,
770
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
771
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
772
+ output_hidden_states = (
773
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
774
+ )
775
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
776
+
777
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
778
+
779
+ if (input_ids is None) ^ (inputs_embeds is not None):
780
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
781
+
782
+ if self.gradient_checkpointing and self.training:
783
+ if use_cache:
784
+ logger.warning_once(
785
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
786
+ )
787
+ use_cache = False
788
+
789
+ # kept for BC (non `Cache` `past_key_values` inputs)
790
+ #return_legacy_cache = False
791
+ if use_cache and not isinstance(past_key_values, RWKV6State):
792
+ #return_legacy_cache = True
793
+ past_key_values = RWKV6State()
794
+ # if past_key_values is None:
795
+ # past_key_values = DynamicCache()
796
+ # else:
797
+ # past_key_values = DynamicCache.from_legacy_cache(past_key_values)
798
+ # logger.warning_once(
799
+ # "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
800
+ # "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
801
+ # "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
802
+ # )
803
+
804
+ if inputs_embeds is None:
805
+ inputs_embeds = self.embed_tokens(input_ids)
806
+
807
+ if cache_position is None:
808
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
809
+ cache_position = torch.arange(
810
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
811
+ )
812
+
813
+ if position_ids is None:
814
+ position_ids = cache_position.unsqueeze(0)
815
+
816
+ # causal_mask = self._update_causal_mask(
817
+ # attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
818
+ # )
819
+
820
+ causal_mask = None
821
+
822
+ hidden_states = inputs_embeds
823
+
824
+ # create position embeddings to be shared across the decoder layers
825
+ position_embeddings = None
826
+ if self.config.use_rope:
827
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
828
+
829
+ # decoder layers
830
+ all_hidden_states = () if output_hidden_states else None
831
+ all_self_attns = () if output_attentions else None
832
+ next_decoder_cache = None
833
+
834
+ for decoder_layer in self.layers:
835
+ if output_hidden_states:
836
+ all_hidden_states += (hidden_states,)
837
+
838
+ if self.gradient_checkpointing and self.training:
839
+ layer_outputs = self._gradient_checkpointing_func(
840
+ decoder_layer.__call__,
841
+ hidden_states,
842
+ causal_mask,
843
+ position_ids,
844
+ past_key_values,
845
+ output_attentions,
846
+ use_cache,
847
+ cache_position,
848
+ position_embeddings,
849
+ )
850
+ else:
851
+ layer_outputs = decoder_layer(
852
+ hidden_states,
853
+ attention_mask=attention_mask,
854
+ position_ids=position_ids,
855
+ past_key_values=past_key_values,
856
+ output_attentions=output_attentions,
857
+ use_cache=use_cache,
858
+ cache_position=cache_position,
859
+ position_embeddings=position_embeddings,
860
+ )
861
+
862
+ hidden_states = layer_outputs[0]
863
+
864
+ if output_attentions:
865
+ all_self_attns += (layer_outputs[1],)
866
+
867
+ hidden_states = self.norm(hidden_states)
868
+
869
+ # add hidden states from the last decoder layer
870
+ if output_hidden_states:
871
+ all_hidden_states += (hidden_states,)
872
+
873
+ #if return_legacy_cache:
874
+ # next_cache = next_cache.to_legacy_cache()
875
+
876
+ if not return_dict:
877
+ return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
878
+ return BaseModelOutputWithPast(
879
+ last_hidden_state=hidden_states,
880
+ past_key_values=past_key_values,
881
+ hidden_states=all_hidden_states,
882
+ attentions=all_self_attns,
883
+ )
884
+
885
+ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
886
+ _tied_weights_keys = ["lm_head.weight"]
887
+
888
+ def __init__(self, config):
889
+ super().__init__(config)
890
+ self.model = RWKV6Qwen2Model(config)
891
+ self.vocab_size = config.vocab_size
892
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
893
+
894
+ # Initialize weights and apply final processing
895
+ self.post_init()
896
+
897
+ def get_input_embeddings(self):
898
+ return self.model.embed_tokens
899
+
900
+ def set_input_embeddings(self, value):
901
+ self.model.embed_tokens = value
902
+
903
+ def get_output_embeddings(self):
904
+ return self.lm_head
905
+
906
+ def set_output_embeddings(self, new_embeddings):
907
+ self.lm_head = new_embeddings
908
+
909
+ def set_decoder(self, decoder):
910
+ self.model = decoder
911
+
912
+ def get_decoder(self):
913
+ return self.model
914
+
915
+ @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
916
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
917
+ def forward(
918
+ self,
919
+ input_ids: torch.LongTensor = None,
920
+ attention_mask: Optional[torch.Tensor] = None,
921
+ position_ids: Optional[torch.LongTensor] = None,
922
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
923
+ inputs_embeds: Optional[torch.FloatTensor] = None,
924
+ labels: Optional[torch.LongTensor] = None,
925
+ use_cache: Optional[bool] = None,
926
+ output_attentions: Optional[bool] = None,
927
+ output_hidden_states: Optional[bool] = None,
928
+ return_dict: Optional[bool] = None,
929
+ cache_position: Optional[torch.LongTensor] = None,
930
+ num_logits_to_keep: int = 0,
931
+ **loss_kwargs,
932
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
933
+ r"""
934
+ Args:
935
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
936
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
937
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
938
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
939
+
940
+ num_logits_to_keep (`int`, *optional*):
941
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
942
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
943
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
944
+
945
+ Returns:
946
+
947
+ Example:
948
+
949
+ ```python
950
+ >>> from transformers import AutoTokenizer, RWKV6Qwen2ForCausalLM
951
+
952
+ >>> model = RWKV6Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
953
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
954
+
955
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
956
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
957
+
958
+ >>> # Generate
959
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
960
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
961
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
962
+ ```"""
963
+
964
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
965
+ output_hidden_states = (
966
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
967
+ )
968
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
969
+
970
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
971
+ outputs = self.model(
972
+ input_ids=input_ids,
973
+ attention_mask=attention_mask,
974
+ position_ids=position_ids,
975
+ past_key_values=past_key_values,
976
+ inputs_embeds=inputs_embeds,
977
+ use_cache=use_cache,
978
+ output_attentions=output_attentions,
979
+ output_hidden_states=output_hidden_states,
980
+ return_dict=return_dict,
981
+ cache_position=cache_position,
982
+ )
983
+
984
+ hidden_states = outputs[0]
985
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
986
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
987
+
988
+ loss = None
989
+ if labels is not None:
990
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
991
+
992
+ if not return_dict:
993
+ output = (logits,) + outputs[1:]
994
+ return (loss,) + output if loss is not None else output
995
+
996
+ return CausalLMOutputWithPast(
997
+ loss=loss,
998
+ logits=logits,
999
+ past_key_values=outputs.past_key_values,
1000
+ hidden_states=outputs.hidden_states,
1001
+ attentions=outputs.attentions,
1002
+ )
1003
+
1004
+ def prepare_inputs_for_generation(
1005
+ self,
1006
+ input_ids: torch.LongTensor,
1007
+ past_key_values: Optional[Cache] = None,
1008
+ attention_mask: Optional[torch.LongTensor] = None,
1009
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1010
+ cache_position: Optional[torch.LongTensor] = None,
1011
+ **kwargs,
1012
+ ):
1013
+ # keep only the last token of `input_ids` if `past_key_values` is not empty.
1014
+ if past_key_values is not None and len(past_key_values) > 0:
1015
+ input_ids = input_ids[:, -1:]
1016
+
1017
+ model_inputs = {
1018
+ 'past_key_values': past_key_values,
1019
+ 'attention_mask': attention_mask,
1020
+ 'cache_position': cache_position,
1021
+ }
1022
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1023
+ if inputs_embeds is not None and past_key_values is None:
1024
+ model_inputs['inputs_embeds'] = inputs_embeds
1025
+ else:
1026
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1027
+ # recompiles graphs as the stride of the inputs is a guard.
1028
+ # Ref: https://github.com/huggingface/transformers/pull/29114
1029
+ # TODO: use `next_tokens` directly instead.
1030
+ model_inputs['input_ids'] = input_ids.contiguous()
1031
+
1032
+ model_inputs.update(**kwargs)
1033
+
1034
+ # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
1035
+ model_inputs.pop("labels", None)
1036
+
1037
+ return model_inputs
1038
+
1039
+ @add_start_docstrings(
1040
+ """
1041
+ The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).
1042
+
1043
+ [`RWKV6Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1044
+ (e.g. GPT-2) do.
1045
+
1046
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1047
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1048
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1049
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1050
+ each row of the batch).
1051
+ """,
1052
+ RWKV6QWEN2_START_DOCSTRING,
1053
+ )
1054
+ class RWKV6Qwen2ForSequenceClassification(RWKV6Qwen2PreTrainedModel):
1055
+ def __init__(self, config):
1056
+ super().__init__(config)
1057
+ self.num_labels = config.num_labels
1058
+ self.model = RWKV6Qwen2Model(config)
1059
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1060
+
1061
+ # Initialize weights and apply final processing
1062
+ self.post_init()
1063
+
1064
+ def get_input_embeddings(self):
1065
+ return self.model.embed_tokens
1066
+
1067
+ def set_input_embeddings(self, value):
1068
+ self.model.embed_tokens = value
1069
+
1070
+ @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
1071
+ def forward(
1072
+ self,
1073
+ input_ids: torch.LongTensor = None,
1074
+ attention_mask: Optional[torch.Tensor] = None,
1075
+ position_ids: Optional[torch.LongTensor] = None,
1076
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1077
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1078
+ labels: Optional[torch.LongTensor] = None,
1079
+ use_cache: Optional[bool] = None,
1080
+ output_attentions: Optional[bool] = None,
1081
+ output_hidden_states: Optional[bool] = None,
1082
+ return_dict: Optional[bool] = None,
1083
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1084
+ r"""
1085
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1086
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1087
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1088
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1089
+ """
1090
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1091
+
1092
+ transformer_outputs = self.model(
1093
+ input_ids,
1094
+ attention_mask=attention_mask,
1095
+ position_ids=position_ids,
1096
+ past_key_values=past_key_values,
1097
+ inputs_embeds=inputs_embeds,
1098
+ use_cache=use_cache,
1099
+ output_attentions=output_attentions,
1100
+ output_hidden_states=output_hidden_states,
1101
+ return_dict=return_dict,
1102
+ )
1103
+ hidden_states = transformer_outputs[0]
1104
+ logits = self.score(hidden_states)
1105
+
1106
+ if input_ids is not None:
1107
+ batch_size = input_ids.shape[0]
1108
+ else:
1109
+ batch_size = inputs_embeds.shape[0]
1110
+
1111
+ if self.config.pad_token_id is None and batch_size != 1:
1112
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1113
+ if self.config.pad_token_id is None:
1114
+ sequence_lengths = -1
1115
+ else:
1116
+ if input_ids is not None:
1117
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1118
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1119
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1120
+ sequence_lengths = sequence_lengths.to(logits.device)
1121
+ else:
1122
+ sequence_lengths = -1
1123
+
1124
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1125
+
1126
+ loss = None
1127
+ if labels is not None:
1128
+ labels = labels.to(logits.device)
1129
+ if self.config.problem_type is None:
1130
+ if self.num_labels == 1:
1131
+ self.config.problem_type = "regression"
1132
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1133
+ self.config.problem_type = "single_label_classification"
1134
+ else:
1135
+ self.config.problem_type = "multi_label_classification"
1136
+
1137
+ if self.config.problem_type == "regression":
1138
+ loss_fct = MSELoss()
1139
+ if self.num_labels == 1:
1140
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1141
+ else:
1142
+ loss = loss_fct(pooled_logits, labels)
1143
+ elif self.config.problem_type == "single_label_classification":
1144
+ loss_fct = CrossEntropyLoss()
1145
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1146
+ elif self.config.problem_type == "multi_label_classification":
1147
+ loss_fct = BCEWithLogitsLoss()
1148
+ loss = loss_fct(pooled_logits, labels)
1149
+ if not return_dict:
1150
+ output = (pooled_logits,) + transformer_outputs[1:]
1151
+ return ((loss,) + output) if loss is not None else output
1152
+
1153
+ return SequenceClassifierOutputWithPast(
1154
+ loss=loss,
1155
+ logits=pooled_logits,
1156
+ past_key_values=transformer_outputs.past_key_values,
1157
+ hidden_states=transformer_outputs.hidden_states,
1158
+ attentions=transformer_outputs.attentions,
1159
+ )
1160
+
1161
+
1162
+ @add_start_docstrings(
1163
+ """
1164
+ The RWKV6Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
1165
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
1166
+ """,
1167
+ RWKV6QWEN2_START_DOCSTRING,
1168
+ )
1169
+ # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->RWKV6Qwen2, LLAMA->RWKV6QWEN2
1170
+ class RWKV6Qwen2ForTokenClassification(RWKV6Qwen2PreTrainedModel):
1171
+ def __init__(self, config):
1172
+ super().__init__(config)
1173
+ self.num_labels = config.num_labels
1174
+ self.model = RWKV6Qwen2Model(config)
1175
+ if getattr(config, "classifier_dropout", None) is not None:
1176
+ classifier_dropout = config.classifier_dropout
1177
+ elif getattr(config, "hidden_dropout", None) is not None:
1178
+ classifier_dropout = config.hidden_dropout
1179
+ else:
1180
+ classifier_dropout = 0.1
1181
+ self.dropout = nn.Dropout(classifier_dropout)
1182
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
1183
+
1184
+ # Initialize weights and apply final processing
1185
+ self.post_init()
1186
+
1187
+ def get_input_embeddings(self):
1188
+ return self.model.embed_tokens
1189
+
1190
+ def set_input_embeddings(self, value):
1191
+ self.model.embed_tokens = value
1192
+
1193
+ @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
1194
+ @add_code_sample_docstrings(
1195
+ checkpoint=_CHECKPOINT_FOR_DOC,
1196
+ output_type=TokenClassifierOutput,
1197
+ config_class=_CONFIG_FOR_DOC,
1198
+ )
1199
+ def forward(
1200
+ self,
1201
+ input_ids: Optional[torch.LongTensor] = None,
1202
+ attention_mask: Optional[torch.Tensor] = None,
1203
+ position_ids: Optional[torch.LongTensor] = None,
1204
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1205
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1206
+ labels: Optional[torch.LongTensor] = None,
1207
+ use_cache: Optional[bool] = None,
1208
+ output_attentions: Optional[bool] = None,
1209
+ output_hidden_states: Optional[bool] = None,
1210
+ return_dict: Optional[bool] = None,
1211
+ ) -> Union[Tuple, TokenClassifierOutput]:
1212
+ r"""
1213
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1214
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1215
+ config.num_labels - 1]`.
1217
+ """
1218
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1219
+
1220
+ outputs = self.model(
1221
+ input_ids,
1222
+ attention_mask=attention_mask,
1223
+ position_ids=position_ids,
1224
+ past_key_values=past_key_values,
1225
+ inputs_embeds=inputs_embeds,
1226
+ use_cache=use_cache,
1227
+ output_attentions=output_attentions,
1228
+ output_hidden_states=output_hidden_states,
1229
+ return_dict=return_dict,
1230
+ )
1231
+ sequence_output = outputs[0]
1232
+ sequence_output = self.dropout(sequence_output)
1233
+ logits = self.score(sequence_output)
1234
+
1235
+ loss = None
1236
+ if labels is not None:
1237
+ loss = self.loss_function(logits, labels, self.config)
1238
+
1239
+ if not return_dict:
1240
+ output = (logits,) + outputs[2:]
1241
+ return ((loss,) + output) if loss is not None else output
1242
+
1243
+ return TokenClassifierOutput(
1244
+ loss=loss,
1245
+ logits=logits,
1246
+ hidden_states=outputs.hidden_states,
1247
+ attentions=outputs.attentions,
1248
+ )
1249
+
1250
+
1251
+ @add_start_docstrings(
1252
+ """
1253
+ The RWKV6Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
1254
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1255
+ """,
1256
+ RWKV6QWEN2_START_DOCSTRING,
1257
+ )
1258
+ # Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->RWKV6Qwen2, MISTRAL->RWKV6QWEN2
1259
+ class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
1260
+ base_model_prefix = "model"
1261
+
1262
+ # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->RWKV6Qwen2
1263
+ def __init__(self, config):
1264
+ super().__init__(config)
1265
+ self.model = RWKV6Qwen2Model(config)
1266
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1267
+
1268
+ # Initialize weights and apply final processing
1269
+ self.post_init()
1270
+
1271
+ def get_input_embeddings(self):
1272
+ return self.model.embed_tokens
1273
+
1274
+ def set_input_embeddings(self, value):
1275
+ self.model.embed_tokens = value
1276
+
1277
+ @add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
1278
+ def forward(
1279
+ self,
1280
+ input_ids: Optional[torch.LongTensor] = None,
1281
+ attention_mask: Optional[torch.FloatTensor] = None,
1282
+ position_ids: Optional[torch.LongTensor] = None,
1283
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1284
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1285
+ start_positions: Optional[torch.LongTensor] = None,
1286
+ end_positions: Optional[torch.LongTensor] = None,
1287
+ output_attentions: Optional[bool] = None,
1288
+ output_hidden_states: Optional[bool] = None,
1289
+ return_dict: Optional[bool] = None,
1290
+ **kwargs,
1291
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1292
+ r"""
1293
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1294
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1295
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1296
+ are not taken into account for computing the loss.
1297
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1298
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1299
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1300
+ are not taken into account for computing the loss.
1301
+ """
1302
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1303
+
1304
+ outputs = self.model(
1305
+ input_ids,
1306
+ attention_mask=attention_mask,
1307
+ position_ids=position_ids,
1308
+ past_key_values=past_key_values,
1309
+ inputs_embeds=inputs_embeds,
1310
+ output_attentions=output_attentions,
1311
+ output_hidden_states=output_hidden_states,
1312
+ return_dict=return_dict,
1313
+ )
1314
+
1315
+ sequence_output = outputs[0]
1316
+
1317
+ logits = self.qa_outputs(sequence_output)
1318
+ start_logits, end_logits = logits.split(1, dim=-1)
1319
+ start_logits = start_logits.squeeze(-1).contiguous()
1320
+ end_logits = end_logits.squeeze(-1).contiguous()
1321
+
1322
+ loss = None
1323
+ if start_positions is not None and end_positions is not None:
1324
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1325
+
1326
+ if not return_dict:
1327
+ output = (start_logits, end_logits) + outputs[2:]
1328
+ return ((loss,) + output) if loss is not None else output
1329
+
1330
+ return QuestionAnsweringModelOutput(
1331
+ loss=loss,
1332
+ start_logits=start_logits,
1333
+ end_logits=end_logits,
1334
+ hidden_states=outputs.hidden_states,
1335
+ attentions=outputs.attentions,
1336
+ )
qwen2.py ADDED
@@ -0,0 +1,670 @@
1
+ import os, math, gc, importlib.util
2
+ import torch
3
+ import torch.utils.checkpoint
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ from torch import Tensor
7
+ from typing import Tuple, Optional
8
+
9
+ #from src.state import ModelState, BlockState, ChannelMixState, TimeMixState, Shared
10
+
11
+ #from configs import TrainerCLI_Config, Model_Config, Transformer_Config, Train_Config
12
+ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
13
+
14
+ #from src.rotary import generate_rotary_embedding, generate_binary_rotary_embedding, apply_rotary_embedding
15
+
16
+ #from src.CoreDependencies import *
17
+
18
+ from dataclasses import dataclass
19
+
20
+ import torch.utils.checkpoint
21
+ if importlib.util.find_spec('deepspeed'):
22
+ import deepspeed
23
+
24
+ from logger import print0 as print
25
+
26
+ from fla.ops.gla.chunk import chunk_gla
27
+ from fla.ops.gla.fused_recurrent import fused_recurrent_gla
28
+
29
+ class ModelState:
30
+ def __init__(self):
31
+ self.seq_pos = 0
32
+ self.input_tokens_cache = torch.tensor([])
33
+ self.k_cache = torch.tensor([])
34
+ self.block_states:list[BlockState] = []
35
+
36
+ class TimeMixState:
37
+ def __init__(self, wkv_state=torch.tensor([]), shift_state=torch.tensor([])):
38
+ self.wkv_state = wkv_state
39
+ self.shift_state = shift_state
40
+
41
+ class ChannelMixState:
42
+ def __init__(self, shift_state=torch.tensor([])):
43
+ self.shift_state = shift_state
44
+
45
+ class BlockState:
46
+ def __init__(self, time_mix_state: TimeMixState, channel_mix_state: ChannelMixState):
47
+ self.time_mix_state = time_mix_state
48
+ self.channel_mix_state = channel_mix_state
49
+
50
+ class Shared:
51
+ def __init__(self):
52
+ self.angles = torch.tensor([])
53
+ self.bias_mask = torch.tensor([])
54
+
55
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
56
+ class Qwen2RMSNorm(nn.Module):
57
+ def __init__(self, hidden_size, eps=1e-6):
58
+ """
59
+ Qwen2RMSNorm is equivalent to T5LayerNorm
60
+ """
61
+ super().__init__()
62
+ self.weight = nn.Parameter(torch.ones(hidden_size))
63
+ self.variance_epsilon = eps
64
+
65
+ def forward(self, hidden_states):
66
+ input_dtype = hidden_states.dtype
67
+ hidden_states = hidden_states.to(torch.float32)
68
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
69
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
70
+ return self.weight * hidden_states.to(input_dtype)
71
+
72
+ def extra_repr(self):
73
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
74
+
75
+ def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
76
+ #inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
77
+
78
+ angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
79
+ angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
80
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
81
+ emb = torch.cat((angles, angles), dim=-1)
82
+ return torch.stack([emb.cos(), emb.sin()], dim=0)
83
+ #return torch.polar(torch.ones_like(angles), angles)
84
+
85
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
86
+ def rotate_half(x):
87
+ """Rotates half the hidden dims of the input."""
88
+ x1 = x[..., : x.shape[-1] // 2]
89
+ x2 = x[..., x.shape[-1] // 2 :]
90
+ return torch.cat((-x2, x1), dim=-1)
91
+
92
+
93
+ # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
94
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim:int=1):
95
+ B, L = q.size(0), q.size(-2)
96
+ cos = cos[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
97
+ sin = sin[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
98
+ q_embed = (q * cos) + (rotate_half(q) * sin)
99
+ k_embed = (k * cos) + (rotate_half(k) * sin)
100
+ return q_embed, k_embed
101
+
102
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
103
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
104
+ """
105
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
106
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
107
+ """
108
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
109
+ if n_rep == 1:
110
+ return hidden_states
111
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
112
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
113
+
114
+ def get_tmix_default_state(x:Tensor, config:Qwen2Config, requires_grad:bool):
115
+ B, T, C = x.size()
116
+ return TimeMixState(
117
+ torch.zeros([B, config.num_attention_heads, config.hidden_size // config.num_attention_heads, config.hidden_size // config.num_attention_heads], dtype=x.dtype, device=x.device, requires_grad=requires_grad),
118
+ torch.zeros([B, C], dtype=x.dtype, device=x.device, requires_grad=requires_grad)
119
+ )
120
+
121
+ @dataclass
122
+ class LLMOutput:
123
+ logits: torch.FloatTensor = None
124
+ model_state: ModelState = None
125
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
126
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
127
+ post_attention_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
128
+ student_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
129
+ student_post_attention_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
130
+
131
+ class TMix_qwen2(nn.Module):
132
+ def get_default_state_factory(self): return get_tmix_default_state
133
+
134
+ def __init__(self, config:Qwen2Config, layer_id):
135
+ super().__init__()
136
+ self.config = config
137
+ self.layer_id = layer_id
138
+ self.ctx_len = config.max_position_embeddings
139
+
140
+ self.head_dim = config.hidden_size // config.num_attention_heads
141
+
142
+ self.hidden_size = config.hidden_size
143
+ self.num_heads = config.num_attention_heads
144
+ self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads > 0 else self.num_heads
145
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
146
+ # self.max_position_embeddings = config.max_position_embeddings
147
+ # self.rope_theta = config.rope_theta
148
+ # self.is_causal = True
149
+ # self.attention_dropout = config.attention_dropout
150
+
151
+ if (self.head_dim * self.num_heads) != self.hidden_size:
152
+ raise ValueError(
153
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
154
+ f" and `num_heads`: {self.num_heads})."
155
+ )
156
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
157
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
158
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
159
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
160
+
161
+ # self.rotary_emb = Qwen2RotaryEmbedding(
162
+ # self.head_dim,
163
+ # max_position_embeddings=config.rope.max_seqlen,
164
+ # base=config.rope.base,
165
+ # )
166
+
167
+ def forward(self, x, last_model_state:ModelState, shared:Shared, output_attentions:bool=False):
168
+ last_state = last_model_state.block_states[self.layer_id].time_mix_state
169
+ B, L, D = x.size()
170
+ QH = self.num_heads
171
+ KVH = self.num_key_value_heads
172
+
173
+ q = self.q_proj(x)
174
+ k = self.k_proj(x)
175
+ v = self.v_proj(x)
176
+
177
+ wkv_state = last_state.wkv_state
178
+
179
+ # handle recurrent inference via maintaining a kv cache
180
+ # if not self.training:
181
+ # new_kv_cache = torch.stack([k, v], dim=0)
182
+ # wkv_state = torch.cat([wkv_state, new_kv_cache], dim=-2)
183
+ # k, v = wkv_state.unbind(0)
184
+ # k, v = k.contiguous(), v.contiguous()
185
+
186
+ is_causal = q.size(1)==k.size(1)
187
+
188
+ q = q.view(B,L,QH,-1).transpose(1,2)
189
+ k = k.view(B,L,KVH,-1).transpose(1,2)
190
+ v = v.view(B,L,KVH,-1).transpose(1,2)
191
+
192
+ #q, k = apply_rotary_embedding(q, k, shared.angles)
193
+ #kv_seq_len, position_ids = L, torch.arange(L, dtype=torch.int, device=v.device).view(1, L).expand(B, L)
194
+ #cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
195
+ cos, sin = shared.angles.unbind(0)
196
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
197
+ q = q.to(v.dtype)
198
+ k = k.to(v.dtype)
199
+
200
+ # repeat k/v heads if n_kv_heads < n_heads
201
+ k = repeat_kv(k, self.num_key_value_groups)
202
+ v = repeat_kv(v, self.num_key_value_groups)
203
+
204
+ if output_attentions:
205
+ attn_weights = (q * (self.head_dim ** -0.5)) @ k.mT
206
+
207
+ #y = nn.functional.softmax(attn_weights + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype) @ v
208
+ #attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
209
+ #y = torch.matmul(attn_weights, v)
210
+
211
+ # NOTE - we are outputting the non-softmaxed attention weights, just with exp() maxed to 1.0 since we're comparing against pre-normalized output of linear attention
212
+ # upcast attention to fp32
213
+ causal_mask = torch.full([L, L], fill_value=-torch.inf, device=attn_weights.device, dtype=attn_weights.dtype).triu(1)
214
+
215
+ attn_weights = nn.functional.softmax(attn_weights + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype)
216
+
217
+ #attn_weights = attn_weights.tril()
218
+ #attn_weights = (attn_weights - attn_weights.max() + causal_mask).exp()
219
+ #attn_weights = (attn_weights - torch.max(attn_weights, dim=-1, keepdim=True).values + causal_mask).exp()
220
+ else:
221
+ attn_weights = torch.empty(0, device=x.device)
222
+
223
+ y = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
224
+ y = y.transpose(1,2).reshape(B,L,D)
225
+ y = self.o_proj(y)
226
+ return y, TimeMixState(wkv_state, last_state.shift_state), attn_weights
227
+
228
+ class TMix_qwen2rwkv(TMix_qwen2):
229
+ """
230
+ Qwen2 RWKV-6cSimple attention module, following the Qwen2 attention module. This module inherits from `TMix_qwen2`
231
+ and adds RWKV-specific weights for tokenshift and data-dependent decay.
232
+ """
233
+
234
+ def __init__(self, config:Qwen2Config, layer_id):
235
+ super().__init__(config, layer_id)
236
+
237
+ n_layer = config.num_hidden_layers
238
+ n_embd = self.hidden_size
239
+ dim_att = self.num_heads * self.head_dim
240
+ layer_id = self.layer_id
241
+
242
+ with torch.no_grad():
243
+ ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
244
+ ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
245
+ ddd = torch.ones(1, 1, n_embd)
246
+ for i in range(n_embd):
247
+ ddd[0, 0, i] = i / n_embd
248
+
249
+ # self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
250
+ # self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
251
+ # self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
252
+ # self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
253
+ # self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
254
+
255
+ ddd = torch.zeros(1, 1, n_embd)
256
+ self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
257
+ self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
258
+ self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
259
+ self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
260
+ self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
261
+ self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
262
+
263
+ D_MIX_LORA = 32 if n_embd < 4096 else 64
264
+ self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, n_embd).uniform_(-0.01, 0.01))
265
+ self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, D_MIX_LORA*self.time_maa_w2.size(0)))
266
+
267
+ # # per-head RWKV-6
268
+ # H = self.num_heads
269
+ # # fancy time_decay
270
+ # decay_speed = torch.ones(H)
271
+ # for h in range(H):
272
+ # decay_speed[h] = -6 + 5 * (h / max(H - 1, 1)) ** (0.7 + 1.3 * ratio_0_to_1)
273
+ # self.time_decay = nn.Parameter(decay_speed)
274
+ # #self.time_decay = nn.Parameter(torch.empty(H)).uniform_(-8, -7)
275
+ # D_DECAY_LORA = 64 if n_embd < 4096 else 128
276
+ # self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
277
+ # self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, H).uniform_(-0.01, 0.01))
278
+
279
+ # RWKV-6
280
+ decay_speed = torch.ones(dim_att)
281
+ for n in range(dim_att):
282
+ decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
283
+ self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
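+ # initializes per-channel decays in [-6, -1]; after w_log = -exp(.) the per-step decay factor
+ # spans roughly exp(-exp(-6)) ~ 0.998 (long memory) down to exp(-exp(-1)) ~ 0.69 (short memory),
+ # with deeper layers biased toward slower decay via ratio_0_to_1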
284
+ D_DECAY_LORA = 64 if n_embd < 4096 else 128
285
+ self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
286
+ self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, dim_att).uniform_(-0.01, 0.01))
287
+ # tmp = torch.zeros(dim_att)
288
+ # for n in range(dim_att):
289
+ # zigzag = ((n + 1) % 3 - 1) * 0.1
290
+ # tmp[n] = ratio_0_to_1 * (1 - (n / (dim_att - 1))) + zigzag
291
+ # self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
292
+
293
+ self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
294
+ # start gate out with no effect
295
+ nn.init.zeros_(self.gate.weight)
296
+ #nn.init.ones_(self.gate.bias)
297
+
298
+ #self.ln_x = nn.LayerNorm(dim_att)
299
+
300
+ def segsum(self, w_log): # B H L 1
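+ # builds the lower-triangular cumulative-decay matrix: w_mask[..., i, j] = exp(sum of w_log over
+ # positions j+1..i) for j <= i, and 0 above the diagonal; used in the output_attentions branch
+ # to weight the raw q.k scores by how much each key has decayed by the time of the query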
301
+ w_log_cumsum = torch.cumsum(w_log, dim=-2) # (B, H, L, 1)
302
+ w_mask = torch.exp((w_log_cumsum - w_log_cumsum.mT).tril()).tril() # (B, H, L, L)
303
+ return w_mask
304
+
305
+ def forward(self, x, last_model_state:ModelState, shared:Shared, output_attentions:bool=False):
306
+ last_state = last_model_state.block_states[self.layer_id].time_mix_state
307
+ bsz, q_len, hidden_dim = x.size()
308
+
309
+ dxprev = torch.nn.functional.pad(x, (0, 0, 1, -1)) - x
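+ # padding by (0, 0, 1, -1) shifts the sequence right by one step (zeros at t=0), so dxprev is
+ # x_{t-1} - x_t: the token-shift delta that drives the data-dependent mixing below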
310
+
311
+ xxx = x + dxprev * self.time_maa_x
312
+ xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
313
+ xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)
314
+
315
+ mr, mk, mv, mw, mg = xxx.unbind(dim=0)
316
+ xr = x + dxprev * (self.time_maa_r + mr)
317
+ xk = x + dxprev * (self.time_maa_k + mk)
318
+ xv = x + dxprev * (self.time_maa_v + mv)
319
+ xw = x + dxprev * (self.time_maa_w + mw)
320
+ xg = x + dxprev * (self.time_maa_g + mg)
321
+
322
+ query_states = self.q_proj(xr)
323
+ key_states = self.k_proj(xk)
324
+ value_states = self.v_proj(xv)
325
+ decay_states = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(query_states.dtype)
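+ # per-token decay = learned per-channel base (time_decay) plus a low-rank, data-dependent
+ # correction projected from the token-shifted input xw, as in RWKV-6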
326
+ gate_states = torch.sigmoid(self.gate(xg))
327
+
328
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
329
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
330
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
331
+ decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
332
+
333
+ # repeat k/v heads if n_kv_heads < n_heads
334
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
335
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
336
+ #dropout_rate = 0.0 if not self.training else self.attention_dropout
337
+
338
+ decay_states_log = -decay_states.float().exp()
339
+ decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
340
+ key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)
341
+
342
+ query_states = query_states.to(value_states.dtype)
343
+ key_states = key_states.to(value_states.dtype)
344
+
345
+ # # In PEFT, usually we cast the layer norms in float32 for training stability reasons
346
+ # # therefore the input hidden states gets silently casted in float32. Hence, we need
347
+ # # cast them back in float16 just to be sure everything works as expected.
348
+ # input_dtype = query_states.dtype
349
+ # if input_dtype == torch.float32:
350
+ # if torch.is_autocast_enabled():
351
+ # target_dtype = torch.get_autocast_gpu_dtype()
352
+ # # Handle the case where the model is quantized
353
+ # elif hasattr(self.config, "_pre_quantization_dtype"):
354
+ # target_dtype = self.config._pre_quantization_dtype
355
+ # else:
356
+ # target_dtype = self.q_proj.weight.dtype
357
+
358
+ # logger.warning_once(
359
+ # f"The input hidden states seems to be silently casted in float32, this might be related to"
360
+ # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
361
+ # f" {target_dtype}."
362
+ # )
363
+
364
+ # query_states = query_states.to(target_dtype)
365
+ # key_states = key_states.to(target_dtype)
366
+ # value_states = value_states.to(target_dtype)
367
+
368
+ # decay_states_log.view is to match fla_chunk_simple_gla's requirements
369
+ #print("layer", self.layer_id, "pre ", bool(query_states.isnan().any()), bool(key_states.isnan().any()), bool(value_states.isnan().any()), bool(decay_states_log.isnan().any()))
370
+ #o = chunk_simple_gla(q.contiguous(), k.contiguous(), v.contiguous(), g.contiguous(), scale)
371
+
372
+ #print("layer", self.layer_id, "post", bool(query_states.isnan().any()), bool(key_states.isnan().any()), bool(value_states.isnan().any()), bool(decay_states_log.isnan().any()))
373
+
374
+ if not output_attentions:
375
+ attn_weights = torch.empty(0, device=x.device)
376
+
377
+ #attn_output = fla_chunk_simple_gla(query_states, key_states, value_states, decay_states_log.view(bsz, self.num_heads, q_len))[0]
378
+ #attn_output = chunk_gla(query_states, key_states, value_states, decay_states_log)[0]
379
+ attn_output = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log)[0]
380
+ attn_output = attn_output.transpose(1, 2).contiguous()
381
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
382
+ #attn_output = self.ln_x(attn_output)
383
+ attn_output = self.o_proj(attn_output * gate_states)
384
+ else:
385
+ attn_weights = (query_states * (key_states.size(-1) ** -0.5)) @ key_states.mT
386
+
387
+ decay_states_log = decay_states_log.mean(-1, keepdim=True)
388
+ attn_weights = attn_weights.float() * self.segsum(decay_states_log.float()) # NOTE - without the explicit cast to float, DDP results mismatched DeepSpeed here
389
+
390
+ attn_weights = attn_weights.to(query_states.dtype)
391
+ attn_output = torch.empty(0, device=x.device)
392
+
393
+ return attn_output, TimeMixState(last_state.wkv_state, last_state.shift_state), attn_weights #, past_key_value
394
+
395
+ def get_cmix_default_state(x:Tensor, config:Qwen2Config, requires_grad:bool):
396
+ B, T, C = x.size()
397
+ return ChannelMixState(
398
+ torch.zeros([B, C], dtype=x.dtype, device=x.device, requires_grad=requires_grad)
399
+ )
400
+
401
+ class CMix_qwen2(nn.Module):
402
+ def get_default_state_factory(self): return get_cmix_default_state
403
+
404
+ def __init__(self, config:Qwen2Config, layer_id):
405
+ super().__init__()
406
+ self.config = config
407
+ self.layer_id = layer_id
408
+
409
+ self.hidden_size = config.hidden_size
410
+ self.intermediate_size = config.intermediate_size
411
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
412
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
413
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
414
+ self.act_fn = torch.nn.SiLU() #ACT2FN[config.hidden_act]
415
+
416
+ def forward(self, x, last_model_state:ModelState):
417
+ last_state = last_model_state.block_states[self.layer_id].channel_mix_state
418
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)), last_state
419
+
420
+ class Qwen2DecoderLayer(nn.Module):
421
+ def __init__(self, config:Qwen2Config, layer_id:int):
422
+ super().__init__()
423
+
424
+ self.config = config
425
+
426
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
427
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
+
429
+ cmix = CMix_qwen2(config, layer_id)
430
+
431
+ #if args.attention_type == 'rwkv':
432
+ self.self_attn = TMix_qwen2rwkv(config, layer_id)
433
+ #else:
434
+ # self.self_attn = TMix_qwen2(config, layer_id)
435
+ self.default_time_mix_state_factory = self.self_attn.get_default_state_factory() if hasattr(self.self_attn, 'get_default_state_factory') else lambda x, c, r: TimeMixState()
436
+
437
+ self.teacher_attn = None
438
+ #if config.train is not None:
439
+ # if config.train.attention_distillation_stage in (1, 2):
440
+ # self.teacher_attn = TMix_qwen2(config, layer_id)
441
+
442
+ self.default_channel_mix_state_factory = cmix.get_default_state_factory() if hasattr(cmix, 'get_default_state_factory') else lambda x, c, r: ChannelMixState()
443
+ self.mlp = cmix
444
+
445
+ def forward(self, x:Tensor, last_model_state:ModelState, shared:Shared, output_attentions:bool, output_post_attention_hidden_states:bool):
446
+ s = last_model_state
447
+ if self.teacher_attn is not None:
448
+ dx, last_timemix_state, attentions = self.teacher_attn(self.input_layernorm(x), s, shared, output_attentions)
449
+ student_dx, student_last_timemix_state, student_attentions = self.self_attn(self.input_layernorm(x), s, shared, output_attentions)
450
+ else:
451
+ dx, last_timemix_state, attentions = self.self_attn(self.input_layernorm(x), s, shared, output_attentions)
452
+ student_dx, student_last_timemix_state, student_attentions = None, None, None
453
+ if output_post_attention_hidden_states:
454
+ post_attention_hidden_states = dx
455
+ student_post_attention_hidden_states = student_dx
456
+ else:
457
+ post_attention_hidden_states = torch.empty(0, device=x.device)
458
+ student_post_attention_hidden_states = torch.empty(0, device=x.device)
459
+
460
+ x = x + dx
461
+ dx, last_chanmix_state = self.mlp(self.post_attention_layernorm(x), s)
462
+ x = x + dx
463
+ return x, s, attentions, post_attention_hidden_states, student_attentions, student_post_attention_hidden_states
464
+
465
+ def ckpt(block:Qwen2DecoderLayer, *block_args):
466
+ # if block.training and block.config.train.grad_cp == 1 and 'fsdp' not in block.config.train.strategy: # FSDP has its own checkpointing wrapper
467
+ #if "deepspeed" in block.config.train.strategy:
468
+ # results = deepspeed.checkpointing.checkpoint(block, *block_args)
469
+ #else:
470
+ # NOTE - both deepspeed.checkpointing.checkpoint and use_reentrant=True failed miserably (bad loss) when used in conjunction with requires_grad=False params and grad_cp with deepspeed
471
+ results = torch.utils.checkpoint.checkpoint(block, *block_args, use_reentrant=False)
472
+ # else:
473
+ # results = block(*block_args)
474
+ return results
475
+
476
+ class Qwen2Decoder(nn.Module):
477
+ def __init__(self, config:Qwen2Config):
478
+ super().__init__()
479
+
480
+ self.config = config
481
+
482
+ self.shared = Shared()
483
+
484
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) #, config.vocab_padding_idx)
485
+ self.layers = nn.ModuleList(
486
+ [Qwen2DecoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
487
+ )
488
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
489
+
490
+ def forward_preamble(self, x, last_model_state:ModelState|None = None):
491
+ config = self.config
492
+
493
+ B, T, C = x.size()
494
+
495
+ shared = self.shared
496
+ #if config.rope is not None and shared.angles.size(0) == 0:
497
+ # shared.angles = generate_rotary_embedding(config.max_position_embeddings, config.head_size, config.rope_theta).to(self.norm.weight)
498
+
499
+ assert (shared.angles.size(0) == 0 or T <= shared.angles.size(0)) or (shared.bias_mask.size(0) == 0 or T <= shared.bias_mask.size(0))
500
+
501
+ # might need to be true in the future for BPTT support
502
+ requires_grad = self.training
503
+ if last_model_state is None:
504
+ last_model_state = ModelState()
505
+ for layer_id in range(config.num_hidden_layers):
506
+ layer = self.layers[layer_id]
507
+ last_model_state.block_states.append(BlockState(
508
+ layer.default_time_mix_state_factory(x, config, requires_grad),
509
+ layer.default_channel_mix_state_factory(x, config, requires_grad),
510
+ ))
511
+
512
+ return last_model_state
513
+
514
+ def forward(self, token_ids:Tensor|list, last_model_state:ModelState|None = None, output_hidden_states:bool=False, output_attentions:bool=False, output_post_attention_hidden_states:bool=False):
515
+ config = self.config
516
+ if isinstance(token_ids, Tensor):
517
+ B, T = token_ids.size()
518
+ else:
519
+ B = 1
520
+ T = len(token_ids)
521
+ token_ids = torch.tensor(token_ids, device=self.embed_tokens.weight.device, dtype=torch.long, requires_grad=False)[None, :]
522
+
523
+ x = self.embed_tokens(token_ids)
524
+
525
+ last_model_state = self.forward_preamble(x, last_model_state)
526
+
527
+ hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs = (), (), ()
528
+ student_hidden_states_outputs, student_attentions_outputs, student_post_attention_hidden_states_outputs = (), (), ()
529
+ if output_hidden_states:
530
+ hidden_states_outputs += (x,)
531
+ student_hidden_states_outputs += (x,)
532
+ for decoder_layer in self.layers:
533
+ x, s, attentions, post_attention_hidden_states, student_attentions, student_post_attention_hidden_states = ckpt(decoder_layer, x, last_model_state, self.shared, output_attentions, output_post_attention_hidden_states)
534
+ hidden_states_outputs += (x,)
535
+ student_hidden_states_outputs += (x,)
536
+ if output_attentions:
537
+ attentions_outputs += (attentions,)
538
+ student_attentions_outputs += (student_attentions,)
539
+ if output_post_attention_hidden_states:
540
+ post_attention_hidden_states_outputs += (post_attention_hidden_states,)
541
+ student_post_attention_hidden_states_outputs += (student_post_attention_hidden_states,)
542
+
543
+ x = self.norm(x)
544
+ return LLMOutput(x, last_model_state, hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs, student_attentions_outputs, student_post_attention_hidden_states_outputs)
545
+ #return x, last_model_state, hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs # FIXME - not updating state at all
546
+
547
+ class Model_qwen2(nn.Module): # Qwen2CausalLM
548
+ def __init__(self, config:Qwen2Config):
549
+ super().__init__()
550
+
551
+ self.config = config
552
+
553
+ self.model = None
554
+
555
+ self.configure_model()
556
+
557
+ def configure_model(self):
558
+ if self.model is not None:
559
+ return
560
+
561
+ self.model = Qwen2Decoder(self.config)
562
+
563
+ self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
564
+
565
+ def forward(self, input_ids:Tensor|list, last_model_state:ModelState|None = None, output_hidden_states:bool=False, output_attentions:bool=False, output_post_attention_hidden_states:bool=False):
566
+ #print("teacher q min, max", float(self.model.layers[0].self_attn.q_proj.weight.min()), float(self.model.layers[0].self_attn.q_proj.weight.max()))
567
+ results = self.model(input_ids, last_model_state=last_model_state, output_hidden_states=output_hidden_states, output_attentions=output_attentions, output_post_attention_hidden_states=output_post_attention_hidden_states)
568
+ results.logits = self.lm_head(results.logits)
569
+ return results
570
+
571
+ def get_optim_groups(self):
572
+ # separates groups for weight decay and non-weight decay
573
+
574
+ config = self.config
575
+
576
+ weight_decay = 0.0 # train_config.weight_decay
577
+
578
+ # if train_config.attention_distillation_stage in (1, 2):
579
+ # self.requires_grad_(False)
580
+ # for decoder_layer in self.model.layers:
581
+ # if train_config.attention_distillation_stage == 1:
582
+ # decoder_layer.self_attn.time_maa_x.requires_grad_(True)
583
+ # decoder_layer.self_attn.time_maa_r.requires_grad_(True)
584
+ # decoder_layer.self_attn.time_maa_k.requires_grad_(True)
585
+ # decoder_layer.self_attn.time_maa_w.requires_grad_(True)
586
+ # decoder_layer.self_attn.time_maa_w1.requires_grad_(True)
587
+ # decoder_layer.self_attn.time_maa_w2.requires_grad_(True)
588
+ # decoder_layer.self_attn.time_decay.requires_grad_(True)
589
+ # decoder_layer.self_attn.time_decay_w1.requires_grad_(True)
590
+ # decoder_layer.self_attn.time_decay_w2.requires_grad_(True)
591
+ # # FIXME - wow we removed q, k here by accident and it.. helped??!?!
592
+ # # decoder_layer.self_attn.q_proj.requires_grad_(True)
593
+ # # decoder_layer.self_attn.k_proj.requires_grad_(True)
594
+ # elif train_config.attention_distillation_stage == 2:
595
+ # decoder_layer.self_attn.requires_grad_(True)
596
+
597
+ # FIXME - remove these for full training
598
+ # for decoder_layer in self.model.layers:
599
+ # decoder_layer.post_attention_layernorm.requires_grad_(False)
600
+ # decoder_layer.mlp.requires_grad_(False)
601
+ # self.model.embed_tokens.requires_grad_(False)
602
+ # self.model.norm.requires_grad_(False)
603
+ # self.lm_head.requires_grad_(False)
604
+
605
+ # JIT at last minute
606
+ for decoder_layer in self.model.layers:
607
+ decoder_layer.self_attn = TJIT(decoder_layer.self_attn)
608
+ decoder_layer.mlp = TJIT(decoder_layer.mlp)
609
+
610
+ lr_decay = set()
611
+ lr_1x = set()
612
+ lr_fp32 = set()
613
+ for n, p in self.named_parameters():
614
+ if not p.requires_grad:
615
+ continue
616
+ if 'lm_head.weight' in n or 'embed_tokens.weight' in n:
617
+ lr_fp32.add(n)
618
+ elif (len(p.squeeze().shape) >= 2) and (weight_decay > 0):
619
+ lr_decay.add(n)
620
+ else:
621
+ lr_1x.add(n)
622
+
623
+ param_dict = {n: p for n, p in self.named_parameters()}
624
+ param_check = list(lr_decay) + list(lr_1x) + list(lr_fp32)
625
+ #if not train_config.load_partial and (train_config.teacher is None or train_config.teacher.attention_distillation_stage ==3):
626
+ # assert sorted(param_dict) == sorted(param_check)
627
+
628
+ lr_decay = sorted(list(lr_decay))
629
+ lr_1x = sorted(list(lr_1x))
630
+ lr_fp32 = sorted(list(lr_fp32))
631
+
632
+ print('decay', lr_decay, '\n')
633
+ print('1x', lr_1x, '\n')
634
+ print('fp32', lr_fp32, '\n')
635
+
636
+
637
+ optim_groups = [
638
+ {"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "use_fp16": True, "my_lr_scale": 1.0, 'name':'lr_1x'},
639
+ ]
640
+ if len(lr_fp32) > 0:
641
+ optim_groups += [{"params": [param_dict[n] for n in lr_fp32], "weight_decay": weight_decay, "my_lr_scale": 1.0, 'name':'lr_tiny'}]
642
+ if len(lr_decay) > 0:
643
+ optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": weight_decay, "use_fp16": True, "my_lr_scale": 1.0, 'name':'lr_decay'}]
644
+
645
+ return optim_groups
646
+
647
+ # def _init_weights(self, module):
648
+ # std = 0.02 #self.config.initializer_range
649
+ # if isinstance(module, nn.Linear):
650
+ # module.weight.data.normal_(mean=0.0, std=std)
651
+ # if module.bias is not None:
652
+ # module.bias.data.zero_()
653
+ # elif isinstance(module, nn.Embedding):
654
+ # module.weight.data.normal_(mean=0.0, std=std)
655
+ # if module.padding_idx is not None:
656
+ # module.weight.data[module.padding_idx].zero_()
657
+
658
+ def init_all_weights(self):
659
+ # FIXME - ideally the gate would be initialized to identity (zero weights, ones bias); instead we skip weight init entirely, since all weights are loaded or overwritten elsewhere anyway
660
+ pass
661
+ # self.apply(self._init_weights)
662
+ # for n, p in self.named_parameters():
663
+ # requires_grad_temp = p.requires_grad
664
+ # p.requires_grad_(False)
665
+ # if n.endswith('.ln_x.weight'):
666
+ # layer_scale = (1+int(n.split('.')[2])) / self.config.model.n_layer
667
+ # print('.ln_x.weight layer', int(n.split('.')[2]), "scale", (layer_scale ** 0.7))
668
+ # p *= 0.0
669
+ # p += (layer_scale ** 0.7)
670
+ # p.requires_grad = requires_grad_temp
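
For context, here is a minimal sketch of how the parameter groups returned by `get_optim_groups()` above could be handed to a standard optimizer. The base learning rate, the choice of AdamW, and the handling of the custom `my_lr_scale` / `use_fp16` / `name` keys are assumptions for illustration; the actual trainer is not part of this diff.

```python
# Hypothetical consumer of get_optim_groups(); the real training loop is not
# shown in this repo's diff. Extra keys such as "my_lr_scale" and "use_fp16"
# are hints for a custom optimizer, so they are stripped (or applied, in the
# case of the lr scale) before handing the groups to torch.optim.AdamW.
import torch

def build_optimizer(model: "Model_qwen2", base_lr: float = 1e-5) -> torch.optim.Optimizer:
    groups = []
    for g in model.get_optim_groups():
        lr_scale = g.pop("my_lr_scale", 1.0)   # per-group learning-rate multiplier
        g.pop("use_fp16", None)                # not understood by torch.optim.AdamW
        g.pop("name", None)
        g["lr"] = base_lr * lr_scale
        groups.append(g)
    return torch.optim.AdamW(groups, lr=base_lr, betas=(0.9, 0.95), eps=1e-8)
```
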
run_lm_eval.py ADDED
@@ -0,0 +1,96 @@
1
+ import sys, os
2
+ import transformers # imported only to work around a bug in lm_eval 0.4.2
3
+
4
+ import torch
5
+ torch.backends.cudnn.benchmark = True
6
+ torch.backends.cudnn.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_tf32 = True
8
+
9
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
10
+ from configs import parse_cmdline_configs
11
+ from pydoc import locate
12
+
13
+ from lm_eval import evaluator
14
+ from lm_eval.models.huggingface import HFLM
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Callable
18
+
19
+ @dataclass
20
+ class CLI_Config:
21
+ tokenizer_path: str
22
+ model_path: str
23
+ attn_path: str = 'rwkv6attn.RWKV6Attention'
24
+ tasks: str = 'lambada_openai' # arc_challenge, arc_easy, headqa, openbookqa, hellaswag, winogrande, piqa, record, copa, storycloze_2016
25
+ bsz: int|str = 'auto'
26
+ precision: int | str = 'bf16'
27
+ num_fewshot: int = 0
28
+ attn_classes_path: str = 'transformers.models.qwen2.modeling_qwen2.QWEN2_ATTENTION_CLASSES' # 'transformers.models.llama.modeling_llama.LLAMA_ATTENTION_CLASSES'
29
+ seed: int | None = None
30
+ train:Any = None
31
+
32
+ config, errors = parse_cmdline_configs(sys.argv[1:], CLI_Config)
33
+ if errors != '':
34
+ print(errors)
35
+ exit()
36
+
37
+ match config.precision:
38
+ case 32:
39
+ dtype = torch.float32
40
+ case '32':
41
+ dtype = torch.float32
42
+ case 16:
43
+ dtype = torch.float16
44
+ case '16':
45
+ dtype = torch.float16
46
+ case 'bf16':
47
+ dtype = torch.bfloat16
48
+ case _:
49
+ print("Bad precision type specified")
50
+ exit()
51
+
52
+ # avoid 1000 huggingface warnings: "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks..."
53
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
54
+
55
+ print(f'Loading model - {config.model_path}')
56
+
57
+ model_config = AutoConfig.from_pretrained(config.model_path)
58
+
59
+ if config.model_path.startswith('.') or config.model_path.startswith('/'):
60
+ # replace attention classes
61
+ ReplacementSelfAttentionType = locate(config.attn_path)
62
+ assert isinstance(ReplacementSelfAttentionType, Callable)
63
+ attn_classes_dict = locate(config.attn_classes_path)
64
+ assert isinstance(attn_classes_dict, dict), 'could not find attention classes dict at path provided'
65
+ for key in list(attn_classes_dict.keys()):
66
+ attn_classes_dict[key] = ReplacementSelfAttentionType
67
+
68
+ model = AutoModelForCausalLM.from_pretrained(config.model_path, config=model_config, torch_dtype=dtype, device_map='cuda')
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
71
+
72
+ #device = 'cuda'
73
+ #model = model.to(device=device, dtype=dtype)
74
+ model.eval()
75
+
76
+ eval_tasks = config.tasks.split(',')
77
+
78
+ if config.seed is None:
79
+ config.seed = 1234
80
+
81
+ adapter = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=config.bsz)
82
+ with torch.no_grad():
83
+ with torch.amp.autocast(device_type='cuda', dtype=dtype):
84
+ results = evaluator.simple_evaluate(
85
+ model=adapter,
86
+ tasks=eval_tasks,
87
+ #provide_description=False,
88
+ num_fewshot=config.num_fewshot,
89
+ limit=None,
90
+ bootstrap_iters=10000,
91
+ numpy_random_seed = config.seed,
92
+ torch_random_seed = config.seed,
93
+ fewshot_random_seed = config.seed,
94
+ )
95
+
96
+ print(results['results'])
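
The attention-swap step above is worth reading as a standalone pattern. Below is a minimal sketch of it, assuming the installed transformers version still exposes the `QWEN2_ATTENTION_CLASSES` registry and that `rwkv6attn.RWKV6Attention` is importable; both names are simply the script's defaults.

```python
# Standalone version of the attention-class swap done in run_lm_eval.py:
# every registered Qwen2 attention implementation ("eager", "sdpa", ...) is
# replaced with the RWKV-6 module before from_pretrained() builds the model.
from pydoc import locate

def swap_attention_classes(
    attn_classes_path: str = "transformers.models.qwen2.modeling_qwen2.QWEN2_ATTENTION_CLASSES",
    attn_path: str = "rwkv6attn.RWKV6Attention",
) -> None:
    replacement = locate(attn_path)
    registry = locate(attn_classes_path)
    assert callable(replacement), "could not locate replacement attention class"
    assert isinstance(registry, dict), "could not locate attention classes dict"
    for key in list(registry.keys()):
        registry[key] = replacement
```
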
tokenization_rwkv6qwen2.py ADDED
@@ -0,0 +1,4 @@
1
+ from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
2
+
3
+ class RWKV6Qwen2Tokenizer(Qwen2Tokenizer):
4
+ pass
tokenization_rwkv6qwen2_fast.py ADDED
@@ -0,0 +1,4 @@
1
+ from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
2
+
3
+ class RWKV6Qwen2TokenizerFast(Qwen2TokenizerFast):
4
+ pass
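
Because the tokenizer and model classes here are custom, a typical way to load this checkpoint is through the Auto classes with remote code enabled. A minimal loading sketch follows; the path is a placeholder, and the bf16/CUDA settings are illustrative assumptions.

```python
# Minimal loading sketch; "path/to/this/checkpoint" is a placeholder, and
# trust_remote_code=True is needed because the model/tokenizer classes ship
# with the repo rather than with the transformers library itself.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "path/to/this/checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
).eval()
```
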
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,207 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ }
181
+ },
182
+ "additional_special_tokens": [
183
+ "<|im_start|>",
184
+ "<|im_end|>",
185
+ "<|object_ref_start|>",
186
+ "<|object_ref_end|>",
187
+ "<|box_start|>",
188
+ "<|box_end|>",
189
+ "<|quad_start|>",
190
+ "<|quad_end|>",
191
+ "<|vision_start|>",
192
+ "<|vision_end|>",
193
+ "<|vision_pad|>",
194
+ "<|image_pad|>",
195
+ "<|video_pad|>"
196
+ ],
197
+ "bos_token": null,
198
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
199
+ "clean_up_tokenization_spaces": false,
200
+ "eos_token": "<|im_end|>",
201
+ "errors": "replace",
202
+ "model_max_length": 131072,
203
+ "pad_token": "<|endoftext|>",
204
+ "split_special_tokens": false,
205
+ "tokenizer_class": "Qwen2Tokenizer",
206
+ "unk_token": null
207
+ }
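
The chat_template above implements the ChatML-style format (`<|im_start|>` / `<|im_end|>`). A short usage sketch, assuming the tokenizer was loaded as in the earlier example:

```python
# Renders a conversation with the ChatML template from tokenizer_config.json;
# add_generation_prompt=True appends "<|im_start|>assistant\n" so the model
# can continue from there.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Give me a one-line summary of RWKV."},
]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```
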
vocab.json ADDED
The diff for this file is too large to render. See raw diff