Add files using upload-large-folder tool
Browse files- config.json +41 -0
- configuration_rwkv6qwen2.py +206 -0
- examine_ckpt.py +25 -0
- generate.py +119 -0
- generation_config.json +14 -0
- merges.txt +0 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +682 -0
- modeling_rwkv6qwen2.py +1336 -0
- qwen2.py +670 -0
- run_lm_eval.py +96 -0
- tokenization_rwkv6qwen2.py +4 -0
- tokenization_rwkv6qwen2_fast.py +4 -0
- tokenizer.json +0 -0
- tokenizer_config.json +207 -0
- vocab.json +0 -0
config.json
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"RWKV6Qwen2ForCausalLM"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_rwkv6qwen2.RWKV6Qwen2Config",
|
7 |
+
"AutoModelForCausalLM": "modeling_rwkv6qwen2.RWKV6Qwen2ForCausalLM"
|
8 |
+
},
|
9 |
+
"attention_bias": true,
|
10 |
+
"attention_dropout": 0.0,
|
11 |
+
"attention_output_bias": false,
|
12 |
+
"balance_state": true,
|
13 |
+
"bos_token_id": 151643,
|
14 |
+
"eos_token_id": 151643,
|
15 |
+
"gate_rank_type": 1,
|
16 |
+
"groupnorm_att": false,
|
17 |
+
"hidden_act": "silu",
|
18 |
+
"hidden_size": 3584,
|
19 |
+
"initializer_range": 0.02,
|
20 |
+
"intermediate_size": 18944,
|
21 |
+
"lora_rank_decay": 96,
|
22 |
+
"lora_rank_tokenshift": 96,
|
23 |
+
"lora_rank_gate": 0,
|
24 |
+
"max_position_embeddings": 131072,
|
25 |
+
"max_window_layers": 28,
|
26 |
+
"model_type": "rwkv6qwen2",
|
27 |
+
"num_attention_heads": 28,
|
28 |
+
"num_hidden_layers": 28,
|
29 |
+
"num_key_value_heads": 4,
|
30 |
+
"rms_norm_eps": 1e-06,
|
31 |
+
"rope_theta": 1000000.0,
|
32 |
+
"sliding_window": 131072,
|
33 |
+
"tie_word_embeddings": false,
|
34 |
+
"torch_dtype": "bfloat16",
|
35 |
+
"transformers_version": "4.43.1",
|
36 |
+
"use_cache": true,
|
37 |
+
"use_rope": false,
|
38 |
+
"use_tokenshift": true,
|
39 |
+
"use_sliding_window": false,
|
40 |
+
"vocab_size": 152064
|
41 |
+
}
|
configuration_rwkv6qwen2.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""RWKV6Qwen2 model configuration"""
|
16 |
+
|
17 |
+
from transformers.configuration_utils import PretrainedConfig
|
18 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class RWKV6Qwen2Config(PretrainedConfig):
|
26 |
+
r"""
|
27 |
+
This is the configuration class to store the configuration of a [`RWKV6Qwen2Model`]. It is used to instantiate a
|
28 |
+
RWKV6Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
29 |
+
with the defaults will yield a similar configuration to that of
|
30 |
+
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
|
31 |
+
|
32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
33 |
+
documentation from [`PretrainedConfig`] for more information.
|
34 |
+
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_size (`int`, *optional*, defaults to 151936):
|
38 |
+
Vocabulary size of the RWKV6Qwen2 model. Defines the number of different tokens that can be represented by the
|
39 |
+
`inputs_ids` passed when calling [`RWKV6Qwen2Model`]
|
40 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
41 |
+
Dimension of the hidden representations.
|
42 |
+
intermediate_size (`int`, *optional*, defaults to 22016):
|
43 |
+
Dimension of the MLP representations.
|
44 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
45 |
+
Number of hidden layers in the Transformer encoder.
|
46 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
47 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
48 |
+
num_key_value_heads (`int`, *optional*, defaults to 32):
|
49 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
50 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
51 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
52 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
53 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
54 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
55 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
56 |
+
The non-linear activation function (function or string) in the decoder.
|
57 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
58 |
+
The maximum sequence length that this model might ever be used with.
|
59 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
60 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
61 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
62 |
+
The epsilon used by the rms normalization layers.
|
63 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
64 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
65 |
+
relevant if `config.is_decoder=True`.
|
66 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
67 |
+
Whether the model's input and output word embeddings should be tied.
|
68 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
69 |
+
The base period of the RoPE embeddings.
|
70 |
+
rope_scaling (`Dict`, *optional*):
|
71 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
72 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
73 |
+
accordingly.
|
74 |
+
Expected contents:
|
75 |
+
`rope_type` (`str`):
|
76 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
77 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
78 |
+
`factor` (`float`, *optional*):
|
79 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
80 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
81 |
+
original maximum pre-trained length.
|
82 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
83 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
84 |
+
pretraining.
|
85 |
+
`attention_factor` (`float`, *optional*):
|
86 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
87 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
88 |
+
`factor` field to infer the suggested value.
|
89 |
+
`beta_fast` (`float`, *optional*):
|
90 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
91 |
+
ramp function. If unspecified, it defaults to 32.
|
92 |
+
`beta_slow` (`float`, *optional*):
|
93 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
94 |
+
ramp function. If unspecified, it defaults to 1.
|
95 |
+
`short_factor` (`List[float]`, *optional*):
|
96 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
97 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
98 |
+
size divided by the number of attention heads divided by 2
|
99 |
+
`long_factor` (`List[float]`, *optional*):
|
100 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
101 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
102 |
+
size divided by the number of attention heads divided by 2
|
103 |
+
`low_freq_factor` (`float`, *optional*):
|
104 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
105 |
+
`high_freq_factor` (`float`, *optional*):
|
106 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
107 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
108 |
+
Whether to use sliding window attention.
|
109 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
110 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
111 |
+
max_window_layers (`int`, *optional*, defaults to 28):
|
112 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
113 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
114 |
+
The dropout ratio for the attention probabilities.
|
115 |
+
|
116 |
+
```python
|
117 |
+
>>> from transformers import RWKV6Qwen2Model, RWKV6Qwen2Config
|
118 |
+
|
119 |
+
>>> # Initializing a RWKV6Qwen2 style configuration
|
120 |
+
>>> configuration = RWKV6Qwen2Config()
|
121 |
+
|
122 |
+
>>> # Initializing a model from the RWKV6Qwen2-7B style configuration
|
123 |
+
>>> model = RWKV6Qwen2Model(configuration)
|
124 |
+
|
125 |
+
>>> # Accessing the model configuration
|
126 |
+
>>> configuration = model.config
|
127 |
+
```"""
|
128 |
+
|
129 |
+
model_type = "rwkv6qwen2"
|
130 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
131 |
+
|
132 |
+
def __init__(
|
133 |
+
self,
|
134 |
+
vocab_size=151936,
|
135 |
+
hidden_size=4096,
|
136 |
+
intermediate_size=22016,
|
137 |
+
num_hidden_layers=32,
|
138 |
+
num_attention_heads=32,
|
139 |
+
num_key_value_heads=32,
|
140 |
+
lora_rank_tokenshift=None,
|
141 |
+
lora_rank_decay=None,
|
142 |
+
hidden_act="silu",
|
143 |
+
max_position_embeddings=32768,
|
144 |
+
initializer_range=0.02,
|
145 |
+
rms_norm_eps=1e-6,
|
146 |
+
use_cache=True,
|
147 |
+
tie_word_embeddings=False,
|
148 |
+
use_rope=False,
|
149 |
+
rope_theta=10000.0,
|
150 |
+
rope_scaling=None,
|
151 |
+
use_sliding_window=False,
|
152 |
+
sliding_window=4096,
|
153 |
+
max_window_layers=28,
|
154 |
+
attention_dropout=0.0,
|
155 |
+
attention_bias=True,
|
156 |
+
attention_output_bias=False,
|
157 |
+
gate_rank_type=1,
|
158 |
+
lora_rank_gate=None,
|
159 |
+
balance_state=True,
|
160 |
+
groupnorm_att=False,
|
161 |
+
use_tokenshift=True,
|
162 |
+
**kwargs,
|
163 |
+
):
|
164 |
+
self.vocab_size = vocab_size
|
165 |
+
self.max_position_embeddings = max_position_embeddings
|
166 |
+
self.hidden_size = hidden_size
|
167 |
+
self.intermediate_size = intermediate_size
|
168 |
+
self.num_hidden_layers = num_hidden_layers
|
169 |
+
self.num_attention_heads = num_attention_heads
|
170 |
+
self.use_sliding_window = use_sliding_window
|
171 |
+
self.sliding_window = sliding_window if use_sliding_window else None
|
172 |
+
self.max_window_layers = max_window_layers
|
173 |
+
|
174 |
+
# for backward compatibility
|
175 |
+
if num_key_value_heads is None:
|
176 |
+
num_key_value_heads = num_attention_heads
|
177 |
+
|
178 |
+
self.num_key_value_heads = num_key_value_heads
|
179 |
+
self.lora_rank_tokenshift = lora_rank_tokenshift
|
180 |
+
self.lora_rank_decay = lora_rank_decay
|
181 |
+
self.hidden_act = hidden_act
|
182 |
+
self.initializer_range = initializer_range
|
183 |
+
self.rms_norm_eps = rms_norm_eps
|
184 |
+
self.use_cache = use_cache
|
185 |
+
self.use_rope = use_rope
|
186 |
+
self.rope_theta = rope_theta
|
187 |
+
self.rope_scaling = rope_scaling
|
188 |
+
self.attention_dropout = attention_dropout
|
189 |
+
# Validate the correctness of rotary position embeddings parameters
|
190 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
191 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
192 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
193 |
+
rope_config_validation(self)
|
194 |
+
|
195 |
+
self.attention_bias = attention_bias
|
196 |
+
self.attention_output_bias = attention_output_bias
|
197 |
+
self.gate_rank_type = gate_rank_type
|
198 |
+
self.lora_rank_gate = lora_rank_gate
|
199 |
+
self.balance_state = balance_state
|
200 |
+
self.groupnorm_att = groupnorm_att
|
201 |
+
self.use_tokenshift = use_tokenshift
|
202 |
+
|
203 |
+
super().__init__(
|
204 |
+
tie_word_embeddings=tie_word_embeddings,
|
205 |
+
**kwargs,
|
206 |
+
)
|
examine_ckpt.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import math
|
3 |
+
import torch
|
4 |
+
from collections import OrderedDict
|
5 |
+
import re
|
6 |
+
from safetensors.torch import load_file
|
7 |
+
|
8 |
+
if len(sys.argv) != 2:
|
9 |
+
print(f"Examines checkpoint keys")
|
10 |
+
print("Usage: python examine_ckpt.py in_file")
|
11 |
+
exit()
|
12 |
+
|
13 |
+
model_path = sys.argv[1]
|
14 |
+
|
15 |
+
print("Loading file...")
|
16 |
+
if model_path.lower().endswith('.safetensors'):
|
17 |
+
state_dict = load_file(model_path)
|
18 |
+
else:
|
19 |
+
state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
|
20 |
+
|
21 |
+
for name, p in state_dict.items():
|
22 |
+
if p.numel() == 0:
|
23 |
+
print(name, p.dtype, p.shape)
|
24 |
+
else:
|
25 |
+
print(name, p.dtype, p.shape, float(p.min()), float(p.max()))
|
generate.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
torch.backends.cudnn.benchmark = True
|
5 |
+
torch.backends.cudnn.allow_tf32 = True
|
6 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
7 |
+
|
8 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
9 |
+
from configs import parse_cmdline_configs
|
10 |
+
from pydoc import locate
|
11 |
+
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from typing import Any, Callable
|
14 |
+
|
15 |
+
moby_dick = """Call me Ishmael. Some years ago—never mind how long precisely—having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating the circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle to prevent me from deliberately stepping into the street, and methodically knocking people’s hats off—then, I account it high time to get to sea as soon as I can. This is my substitute for pistol and ball. With a philosophical flourish Cato throws himself upon his sword; I quietly take to the ship. There is nothing surprising in this. If they but knew it, almost all men in their degree, some time or other, cherish very nearly the same feelings towards the ocean with me.
|
16 |
+
|
17 |
+
There now is your insular city of the Manhattoes, belted round by wharves as Indian isles by coral reefs—commerce surrounds it with her surf. Right and left, the streets take you waterward. Its extreme downtown is the battery, where that noble mole is washed by waves, and cooled by breezes, which a few hours previous were out of sight of land. Look at the crowds of water-gazers there.
|
18 |
+
|
19 |
+
Circumambulate the city of a dreamy Sabbath afternoon. Go from Corlears Hook to Coenties Slip, and from thence, by Whitehall, northward. What do you see?—Posted like silent sentinels all around the town, stand thousands upon thousands of mortal men fixed in ocean reveries. Some leaning against the spiles; some seated upon the pier-heads; some looking over the bulwarks of ships from China; some high aloft in the rigging, as if striving to get a still better seaward peep. But these are all landsmen; of week days pent up in lath and plaster—tied to counters, nailed to benches, clinched to desks. How then is this? Are the green fields gone? What do they here?
|
20 |
+
|
21 |
+
But look! here come more crowds, pacing straight for the water, and seemingly bound for a dive. Strange! Nothing will content them but the extremest limit of the land; loitering under the shady lee of yonder warehouses will not suffice. No. They must get just as nigh the water as they possibly can without falling in. And there they stand—miles of them—leagues. Inlanders all, they come from lanes and alleys, streets and avenues—north, east, south, and west. Yet here they all unite. Tell me, does the magnetic virtue of the needles of the compasses of all those ships attract them thither?
|
22 |
+
|
23 |
+
Once more. Say you are in the country; in some high land of lakes. Take almost any path you please, and ten to one it carries you down in a dale, and leaves you there by a pool in the stream. There is magic in it. Let the most absent-minded of men be plunged in his deepest reveries—stand that man on his legs, set his feet a-going, and he will infallibly lead you to water, if water there be in all that region. Should you ever be athirst in the great American desert, try this experiment, if your caravan happen to be supplied with a metaphysical professor. Yes, as every one knows, meditation and water are wedded for ever.
|
24 |
+
|
25 |
+
But here is an artist. He desires to paint you the dreamiest, shadiest, quietest, most enchanting bit of romantic landscape in all the valley of the Saco. What is the chief element he employs? There stand his trees, each with a hollow trunk, as if a hermit and a crucifix were within; and here sleeps his meadow, and there sleep his cattle; and up from yonder cottage goes a sleepy smoke. Deep into distant woodlands winds a mazy way, reaching to overlapping spurs of mountains bathed in their hill-side blue. But though the picture lies thus tranced, and though this pine-tree shakes down its sighs like leaves upon this shepherd’s head, yet all were vain, unless the shepherd’s eye were fixed upon the magic stream before him. Go visit the Prairies in June, when for scores on scores of miles you wade knee-deep among Tiger-lilies—what is the one charm wanting?—Water—there is not a drop of water there! Were Niagara but a cataract of sand, would you travel your thousand miles to see it? Why did the poor poet of Tennessee, upon suddenly receiving two handfuls of silver, deliberate whether to buy him a coat, which he sadly needed, or invest his money in a pedestrian trip to Rockaway Beach? Why is almost every robust healthy boy with a robust healthy soul in him, at some time or other crazy to go to sea? Why upon your first voyage as a passenger, did you yourself feel such a mystical vibration, when first told that you and your ship were now out of sight of land? Why did the old Persians hold the sea holy? Why did the Greeks give it a separate deity, and own brother of Jove? Surely all this is not without meaning. And still deeper the meaning of that story of Narcissus, who because he could not grasp the tormenting, mild image he saw in the fountain, plunged into it and was drowned. But that same image, we ourselves see in all rivers and oceans. It is the image of the ungraspable phantom of life; and this is the key to it all.
|
26 |
+
|
27 |
+
Now, when I say that I am in the habit of going to sea whenever I begin to grow hazy about the eyes, and begin to be over conscious of my lungs, I do not mean to have it inferred that I ever go to sea as a passenger. For to go as a passenger you must needs have a purse, and a purse is but a rag unless you have something in it. Besides, passengers get sea-sick—grow quarrelsome—don’t sleep of nights—do not enjoy themselves much, as a general thing;—no, I never go as a passenger; nor, though I am something of a salt, do I ever go to sea as a Commodore, or a Captain, or a Cook. I abandon the glory and distinction of such offices to those who like them. For my part, I abominate all honorable respectable toils, trials, and tribulations of every kind whatsoever. It is quite as much as I can do to take care of myself, without taking care of ships, barques, brigs, schooners, and what not. And as for going as cook,—though I confess there is considerable glory in that, a cook being a sort of officer on ship-board—yet, somehow, I never fancied broiling fowls;—though once broiled, judiciously buttered, and judgmatically salted and peppered, there is no one who will speak more respectfully, not to say reverentially, of a broiled fowl than I will. It is out of the idolatrous dotings of the old Egyptians upon broiled ibis and roasted river horse, that you see the mummies of those creatures in their huge bake-houses the pyramids.
|
28 |
+
|
29 |
+
No, when I go to sea, I go as a simple sailor, right before the mast, plumb down into the forecastle, aloft there to the royal mast-head. True, they rather order me about some, and make me jump from spar to spar, like a grasshopper in a May meadow. And at first, this sort of thing is unpleasant enough. It touches one’s sense of honor, particularly if you come of an old established family in the land, the Van Rensselaers, or Randolphs, or Hardicanutes. And more than all, if just previous to putting your hand into the tar-pot, you have been lording it as a country schoolmaster, making the tallest boys stand in awe of you. The transition is a keen one, I assure you, from a schoolmaster to a sailor, and requires a strong decoction of Seneca and the Stoics to enable you to grin and bear it. But even this wears off in time.
|
30 |
+
|
31 |
+
What of it, if some old hunks of a sea-captain orders me to get a broom and sweep down the decks? What does that indignity amount to, weighed, I mean, in the scales of the New Testament? Do you think the archangel Gabriel thinks anything the less of me, because I promptly and respectfully obey that old hunks in that particular instance? Who ain’t a slave? Tell me that. Well, then, however the old sea-captains may order me about—however they may thump and punch me about, I have the satisfaction of knowing that it is all right; that everybody else is one way or other served in much the same way—either in a physical or metaphysical point of view, that is; and so the universal thump is passed round, and all hands should rub each other’s shoulder-blades, and be content.
|
32 |
+
|
33 |
+
Again, I always go to sea as a sailor, because they make a point of paying me for my trouble, whereas they never pay passengers a single penny that I ever heard of. On the contrary, passengers themselves must pay. And there is all the difference in the world between paying and being paid. The act of paying is perhaps the most uncomfortable infliction that the two orchard thieves entailed upon us. But being paid,—what will compare with it? The urbane activity with which a man receives money is really marvellous, considering that we so earnestly believe money to be the root of all earthly ills, and that on no account can a monied man enter heaven. Ah! how cheerfully we consign ourselves to perdition!
|
34 |
+
|
35 |
+
Finally, I always go to sea as a sailor, because """
|
36 |
+
|
37 |
+
|
38 |
+
@dataclass
|
39 |
+
class CLI_Config:
|
40 |
+
tokenizer_path: str
|
41 |
+
model_path: str
|
42 |
+
attn_path: str = 'rwkv6attn.RWKV6Attention'
|
43 |
+
prompt:str = 'How many quarts are in a gallon?'
|
44 |
+
max_len:int = 30
|
45 |
+
attempts:int = 1
|
46 |
+
precision: int | str = 'bf16'
|
47 |
+
attn_classes_path: str = 'transformers.models.qwen2.modeling_qwen2.QWEN2_ATTENTION_CLASSES' # 'transformers.models.llama.modeling_llama.LLAMA_ATTENTION_CLASSES'
|
48 |
+
seed: int | None = None
|
49 |
+
train:Any = None
|
50 |
+
|
51 |
+
config, errors = parse_cmdline_configs(sys.argv[1:], CLI_Config)
|
52 |
+
if errors != '':
|
53 |
+
print(errors)
|
54 |
+
exit()
|
55 |
+
|
56 |
+
match config.precision:
|
57 |
+
case 32:
|
58 |
+
dtype = torch.float32
|
59 |
+
case '32':
|
60 |
+
dtype = torch.float32
|
61 |
+
case 16:
|
62 |
+
dtype = torch.float16
|
63 |
+
case '16':
|
64 |
+
dtype = torch.float16
|
65 |
+
case 'bf16':
|
66 |
+
dtype = torch.bfloat16
|
67 |
+
case _:
|
68 |
+
print("Bad precision type specified")
|
69 |
+
exit()
|
70 |
+
|
71 |
+
# avoid 1000 huggingface warnings "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...""
|
72 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
73 |
+
|
74 |
+
print(f'Loading model - {config.model_path}')
|
75 |
+
|
76 |
+
model_config = AutoConfig.from_pretrained(config.model_path, trust_remote_code=True)
|
77 |
+
|
78 |
+
# if config.model_path.startswith('.'):
|
79 |
+
# # replace attention classes
|
80 |
+
# ReplacementSelfAttentionType = locate(config.attn_path)
|
81 |
+
# assert isinstance(ReplacementSelfAttentionType, Callable)
|
82 |
+
# attn_classes_dict = locate(config.attn_classes_path)
|
83 |
+
# assert isinstance(attn_classes_dict, dict), 'could not find attention classes dict at path provided'
|
84 |
+
# for key in list(attn_classes_dict.keys()):
|
85 |
+
# attn_classes_dict[key] = ReplacementSelfAttentionType
|
86 |
+
|
87 |
+
model = AutoModelForCausalLM.from_pretrained(config.model_path, config=model_config, torch_dtype=dtype, device_map='cuda', trust_remote_code=True)
|
88 |
+
|
89 |
+
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True)
|
90 |
+
|
91 |
+
#device = 'cuda'
|
92 |
+
#model = model.to(device=device, dtype=dtype)
|
93 |
+
model.eval()
|
94 |
+
|
95 |
+
if config.seed is None:
|
96 |
+
config.seed = 1234
|
97 |
+
|
98 |
+
from transformers import AutoTokenizer, Qwen2ForCausalLM, set_seed
|
99 |
+
|
100 |
+
set_seed(config.seed)
|
101 |
+
|
102 |
+
text = config.prompt
|
103 |
+
|
104 |
+
messages = [
|
105 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
106 |
+
{"role": "user", "content": config.prompt}
|
107 |
+
]
|
108 |
+
text = tokenizer.apply_chat_template(
|
109 |
+
messages,
|
110 |
+
tokenize=False,
|
111 |
+
add_generation_prompt=True
|
112 |
+
)
|
113 |
+
inputs = tokenizer(text, return_tensors="pt").to('cuda')
|
114 |
+
|
115 |
+
# Generate
|
116 |
+
for i in range(config.attempts):
|
117 |
+
print(f"Attempt {i+1}:")
|
118 |
+
generate_ids = model.generate(inputs.input_ids, max_new_tokens=config.max_len, use_cache=True, do_sample=True, temperature=1.0, top_p=1.0)#, typical_p=0.95)#top_p=0.7, repetition_penalty=0.25)
|
119 |
+
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False, use_cache=False)[0])
|
generation_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token_id": 151643,
|
3 |
+
"pad_token_id": 151643,
|
4 |
+
"do_sample": true,
|
5 |
+
"eos_token_id": [
|
6 |
+
151645,
|
7 |
+
151643
|
8 |
+
],
|
9 |
+
"repetition_penalty": 1.05,
|
10 |
+
"temperature": 0.7,
|
11 |
+
"top_p": 0.8,
|
12 |
+
"top_k": 20,
|
13 |
+
"transformers_version": "4.37.0"
|
14 |
+
}
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:662ccdc789fe510fdf88cda5727f77061d133d9d31f0591d7a008b8e8add6b04
|
3 |
+
size 4955133312
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b9cdf366e3cca4c985365bb2571668ce7d2d3649071bf270e1166a18208d55c4
|
3 |
+
size 4865370824
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:367ae7e5f35819bf5d72cfb4159fd7d560a7cf943380d53a0e64d8e50b38d61c
|
3 |
+
size 4865370920
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfe5f39f1ff90b14ab41d7b12f852b3e506c40b2230a40735bf181360de839bc
|
3 |
+
size 1497374264
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 16183172096
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
7 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
8 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
9 |
+
"model.layers.0.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
10 |
+
"model.layers.0.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
11 |
+
"model.layers.0.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
12 |
+
"model.layers.0.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
13 |
+
"model.layers.0.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
14 |
+
"model.layers.0.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
15 |
+
"model.layers.0.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
16 |
+
"model.layers.0.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
17 |
+
"model.layers.0.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
18 |
+
"model.layers.0.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
19 |
+
"model.layers.0.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
20 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"model.layers.0.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
22 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"model.layers.0.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
24 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
25 |
+
"model.layers.0.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
26 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
27 |
+
"model.layers.0.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
28 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
29 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
30 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
31 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
32 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
33 |
+
"model.layers.1.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
34 |
+
"model.layers.1.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
35 |
+
"model.layers.1.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
36 |
+
"model.layers.1.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
37 |
+
"model.layers.1.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
38 |
+
"model.layers.1.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
39 |
+
"model.layers.1.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
40 |
+
"model.layers.1.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
41 |
+
"model.layers.1.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
42 |
+
"model.layers.1.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
43 |
+
"model.layers.1.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
44 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
45 |
+
"model.layers.1.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
46 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
47 |
+
"model.layers.1.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
48 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
49 |
+
"model.layers.1.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
50 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
51 |
+
"model.layers.1.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
52 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
53 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
54 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
55 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
56 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
57 |
+
"model.layers.2.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
58 |
+
"model.layers.2.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
59 |
+
"model.layers.2.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
60 |
+
"model.layers.2.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
61 |
+
"model.layers.2.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
62 |
+
"model.layers.2.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
63 |
+
"model.layers.2.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
64 |
+
"model.layers.2.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
65 |
+
"model.layers.2.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
66 |
+
"model.layers.2.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
67 |
+
"model.layers.2.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
68 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
69 |
+
"model.layers.2.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
70 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
71 |
+
"model.layers.2.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
72 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
73 |
+
"model.layers.2.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
74 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
75 |
+
"model.layers.2.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
76 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
77 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
78 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
79 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
80 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
81 |
+
"model.layers.3.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
82 |
+
"model.layers.3.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
83 |
+
"model.layers.3.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
84 |
+
"model.layers.3.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
85 |
+
"model.layers.3.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
86 |
+
"model.layers.3.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
87 |
+
"model.layers.3.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
88 |
+
"model.layers.3.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
89 |
+
"model.layers.3.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
90 |
+
"model.layers.3.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
91 |
+
"model.layers.3.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
92 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
93 |
+
"model.layers.3.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
94 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
95 |
+
"model.layers.3.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
96 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
97 |
+
"model.layers.3.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
98 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
99 |
+
"model.layers.3.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
100 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
101 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
102 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
103 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
104 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
105 |
+
"model.layers.4.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
106 |
+
"model.layers.4.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
107 |
+
"model.layers.4.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
108 |
+
"model.layers.4.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
109 |
+
"model.layers.4.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
110 |
+
"model.layers.4.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
111 |
+
"model.layers.4.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
112 |
+
"model.layers.4.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
113 |
+
"model.layers.4.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
114 |
+
"model.layers.4.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
115 |
+
"model.layers.4.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
116 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
117 |
+
"model.layers.4.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
118 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
119 |
+
"model.layers.4.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
120 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
121 |
+
"model.layers.4.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
122 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
123 |
+
"model.layers.4.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
124 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
125 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
126 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
127 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
128 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
129 |
+
"model.layers.5.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
130 |
+
"model.layers.5.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
131 |
+
"model.layers.5.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
132 |
+
"model.layers.5.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
133 |
+
"model.layers.5.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
134 |
+
"model.layers.5.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
135 |
+
"model.layers.5.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
136 |
+
"model.layers.5.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
137 |
+
"model.layers.5.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
138 |
+
"model.layers.5.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
139 |
+
"model.layers.5.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
140 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
141 |
+
"model.layers.5.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
142 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
143 |
+
"model.layers.5.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
144 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
145 |
+
"model.layers.5.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
146 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
147 |
+
"model.layers.5.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
148 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
149 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
150 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
151 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
152 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
153 |
+
"model.layers.6.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
154 |
+
"model.layers.6.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
155 |
+
"model.layers.6.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
156 |
+
"model.layers.6.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
157 |
+
"model.layers.6.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
158 |
+
"model.layers.6.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
159 |
+
"model.layers.6.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
160 |
+
"model.layers.6.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
161 |
+
"model.layers.6.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
162 |
+
"model.layers.6.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
163 |
+
"model.layers.6.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
164 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
165 |
+
"model.layers.6.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
166 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
167 |
+
"model.layers.6.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
168 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
169 |
+
"model.layers.6.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
170 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
171 |
+
"model.layers.6.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
172 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
173 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
174 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
175 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
176 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
177 |
+
"model.layers.7.self_attn.time_maa_x": "model-00001-of-00004.safetensors",
|
178 |
+
"model.layers.7.self_attn.time_maa_r": "model-00001-of-00004.safetensors",
|
179 |
+
"model.layers.7.self_attn.time_maa_k": "model-00001-of-00004.safetensors",
|
180 |
+
"model.layers.7.self_attn.time_maa_v": "model-00001-of-00004.safetensors",
|
181 |
+
"model.layers.7.self_attn.time_maa_w": "model-00001-of-00004.safetensors",
|
182 |
+
"model.layers.7.self_attn.time_maa_g": "model-00001-of-00004.safetensors",
|
183 |
+
"model.layers.7.self_attn.time_maa_w2": "model-00001-of-00004.safetensors",
|
184 |
+
"model.layers.7.self_attn.time_maa_w1": "model-00001-of-00004.safetensors",
|
185 |
+
"model.layers.7.self_attn.time_decay": "model-00001-of-00004.safetensors",
|
186 |
+
"model.layers.7.self_attn.time_decay_w1": "model-00001-of-00004.safetensors",
|
187 |
+
"model.layers.7.self_attn.time_decay_w2": "model-00001-of-00004.safetensors",
|
188 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
189 |
+
"model.layers.7.self_attn.q_proj.bias": "model-00001-of-00004.safetensors",
|
190 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
191 |
+
"model.layers.7.self_attn.k_proj.bias": "model-00001-of-00004.safetensors",
|
192 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
193 |
+
"model.layers.7.self_attn.v_proj.bias": "model-00001-of-00004.safetensors",
|
194 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
195 |
+
"model.layers.7.self_attn.gate.weight": "model-00001-of-00004.safetensors",
|
196 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
197 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
198 |
+
"model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
199 |
+
"model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
200 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
201 |
+
"model.layers.8.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
202 |
+
"model.layers.8.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
203 |
+
"model.layers.8.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
204 |
+
"model.layers.8.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
205 |
+
"model.layers.8.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
206 |
+
"model.layers.8.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
207 |
+
"model.layers.8.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
208 |
+
"model.layers.8.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
209 |
+
"model.layers.8.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
210 |
+
"model.layers.8.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
211 |
+
"model.layers.8.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
212 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
213 |
+
"model.layers.8.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
214 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
215 |
+
"model.layers.8.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
216 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
217 |
+
"model.layers.8.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
218 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
219 |
+
"model.layers.8.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
220 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
221 |
+
"model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
222 |
+
"model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
223 |
+
"model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
224 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
225 |
+
"model.layers.9.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
226 |
+
"model.layers.9.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
227 |
+
"model.layers.9.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
228 |
+
"model.layers.9.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
229 |
+
"model.layers.9.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
230 |
+
"model.layers.9.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
231 |
+
"model.layers.9.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
232 |
+
"model.layers.9.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
233 |
+
"model.layers.9.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
234 |
+
"model.layers.9.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
235 |
+
"model.layers.9.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
236 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
237 |
+
"model.layers.9.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
238 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
239 |
+
"model.layers.9.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
240 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
241 |
+
"model.layers.9.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
242 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
243 |
+
"model.layers.9.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
244 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
245 |
+
"model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
246 |
+
"model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
247 |
+
"model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
248 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
249 |
+
"model.layers.10.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
250 |
+
"model.layers.10.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
251 |
+
"model.layers.10.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
252 |
+
"model.layers.10.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
253 |
+
"model.layers.10.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
254 |
+
"model.layers.10.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
255 |
+
"model.layers.10.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
256 |
+
"model.layers.10.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
257 |
+
"model.layers.10.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
258 |
+
"model.layers.10.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
259 |
+
"model.layers.10.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
260 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
261 |
+
"model.layers.10.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
262 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
263 |
+
"model.layers.10.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
264 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
265 |
+
"model.layers.10.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
266 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
267 |
+
"model.layers.10.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
268 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
269 |
+
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
270 |
+
"model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
271 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
272 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
273 |
+
"model.layers.11.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
274 |
+
"model.layers.11.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
275 |
+
"model.layers.11.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
276 |
+
"model.layers.11.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
277 |
+
"model.layers.11.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
278 |
+
"model.layers.11.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
279 |
+
"model.layers.11.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
280 |
+
"model.layers.11.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
281 |
+
"model.layers.11.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
282 |
+
"model.layers.11.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
283 |
+
"model.layers.11.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
284 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
285 |
+
"model.layers.11.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
286 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
287 |
+
"model.layers.11.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
288 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
289 |
+
"model.layers.11.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
290 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
291 |
+
"model.layers.11.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
292 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
293 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
294 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
295 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
296 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
297 |
+
"model.layers.12.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
298 |
+
"model.layers.12.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
299 |
+
"model.layers.12.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
300 |
+
"model.layers.12.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
301 |
+
"model.layers.12.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
302 |
+
"model.layers.12.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
303 |
+
"model.layers.12.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
304 |
+
"model.layers.12.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
305 |
+
"model.layers.12.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
306 |
+
"model.layers.12.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
307 |
+
"model.layers.12.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
308 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
309 |
+
"model.layers.12.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
310 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
311 |
+
"model.layers.12.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
312 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
313 |
+
"model.layers.12.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
314 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
315 |
+
"model.layers.12.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
316 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
317 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
318 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
319 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
320 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
321 |
+
"model.layers.13.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
322 |
+
"model.layers.13.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
323 |
+
"model.layers.13.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
324 |
+
"model.layers.13.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
325 |
+
"model.layers.13.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
326 |
+
"model.layers.13.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
327 |
+
"model.layers.13.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
328 |
+
"model.layers.13.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
329 |
+
"model.layers.13.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
330 |
+
"model.layers.13.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
331 |
+
"model.layers.13.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
332 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
333 |
+
"model.layers.13.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
334 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
335 |
+
"model.layers.13.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
336 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
337 |
+
"model.layers.13.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
338 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
339 |
+
"model.layers.13.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
340 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
341 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
342 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
343 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
344 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
345 |
+
"model.layers.14.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
346 |
+
"model.layers.14.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
347 |
+
"model.layers.14.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
348 |
+
"model.layers.14.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
349 |
+
"model.layers.14.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
350 |
+
"model.layers.14.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
351 |
+
"model.layers.14.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
352 |
+
"model.layers.14.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
353 |
+
"model.layers.14.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
354 |
+
"model.layers.14.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
355 |
+
"model.layers.14.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
356 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
357 |
+
"model.layers.14.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
358 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
359 |
+
"model.layers.14.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
360 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
361 |
+
"model.layers.14.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
362 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
363 |
+
"model.layers.14.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
364 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
365 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
366 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
367 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
368 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
369 |
+
"model.layers.15.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
370 |
+
"model.layers.15.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
371 |
+
"model.layers.15.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
372 |
+
"model.layers.15.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
373 |
+
"model.layers.15.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
374 |
+
"model.layers.15.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
375 |
+
"model.layers.15.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
376 |
+
"model.layers.15.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
377 |
+
"model.layers.15.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
378 |
+
"model.layers.15.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
379 |
+
"model.layers.15.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
380 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
381 |
+
"model.layers.15.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
382 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
383 |
+
"model.layers.15.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
384 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
385 |
+
"model.layers.15.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
386 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
387 |
+
"model.layers.15.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
388 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
389 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
390 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
391 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
392 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
393 |
+
"model.layers.16.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
394 |
+
"model.layers.16.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
395 |
+
"model.layers.16.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
396 |
+
"model.layers.16.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
397 |
+
"model.layers.16.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
398 |
+
"model.layers.16.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
399 |
+
"model.layers.16.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
400 |
+
"model.layers.16.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
401 |
+
"model.layers.16.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
402 |
+
"model.layers.16.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
403 |
+
"model.layers.16.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
404 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
405 |
+
"model.layers.16.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
406 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
407 |
+
"model.layers.16.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
408 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
409 |
+
"model.layers.16.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
410 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
411 |
+
"model.layers.16.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
412 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
413 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
414 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
415 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
416 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
417 |
+
"model.layers.17.self_attn.time_maa_x": "model-00002-of-00004.safetensors",
|
418 |
+
"model.layers.17.self_attn.time_maa_r": "model-00002-of-00004.safetensors",
|
419 |
+
"model.layers.17.self_attn.time_maa_k": "model-00002-of-00004.safetensors",
|
420 |
+
"model.layers.17.self_attn.time_maa_v": "model-00002-of-00004.safetensors",
|
421 |
+
"model.layers.17.self_attn.time_maa_w": "model-00002-of-00004.safetensors",
|
422 |
+
"model.layers.17.self_attn.time_maa_g": "model-00002-of-00004.safetensors",
|
423 |
+
"model.layers.17.self_attn.time_maa_w2": "model-00002-of-00004.safetensors",
|
424 |
+
"model.layers.17.self_attn.time_maa_w1": "model-00002-of-00004.safetensors",
|
425 |
+
"model.layers.17.self_attn.time_decay": "model-00002-of-00004.safetensors",
|
426 |
+
"model.layers.17.self_attn.time_decay_w1": "model-00002-of-00004.safetensors",
|
427 |
+
"model.layers.17.self_attn.time_decay_w2": "model-00002-of-00004.safetensors",
|
428 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
429 |
+
"model.layers.17.self_attn.q_proj.bias": "model-00002-of-00004.safetensors",
|
430 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
431 |
+
"model.layers.17.self_attn.k_proj.bias": "model-00002-of-00004.safetensors",
|
432 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
433 |
+
"model.layers.17.self_attn.v_proj.bias": "model-00002-of-00004.safetensors",
|
434 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
435 |
+
"model.layers.17.self_attn.gate.weight": "model-00002-of-00004.safetensors",
|
436 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
437 |
+
"model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
438 |
+
"model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
439 |
+
"model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
440 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
441 |
+
"model.layers.18.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
442 |
+
"model.layers.18.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
443 |
+
"model.layers.18.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
444 |
+
"model.layers.18.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
445 |
+
"model.layers.18.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
446 |
+
"model.layers.18.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
447 |
+
"model.layers.18.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
448 |
+
"model.layers.18.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
449 |
+
"model.layers.18.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
450 |
+
"model.layers.18.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
451 |
+
"model.layers.18.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
452 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
453 |
+
"model.layers.18.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
454 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
455 |
+
"model.layers.18.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
456 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
457 |
+
"model.layers.18.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
458 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
459 |
+
"model.layers.18.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
460 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
461 |
+
"model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
462 |
+
"model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
463 |
+
"model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
464 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
465 |
+
"model.layers.19.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
466 |
+
"model.layers.19.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
467 |
+
"model.layers.19.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
468 |
+
"model.layers.19.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
469 |
+
"model.layers.19.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
470 |
+
"model.layers.19.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
471 |
+
"model.layers.19.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
472 |
+
"model.layers.19.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
473 |
+
"model.layers.19.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
474 |
+
"model.layers.19.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
475 |
+
"model.layers.19.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
476 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
477 |
+
"model.layers.19.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
478 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
479 |
+
"model.layers.19.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
480 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
481 |
+
"model.layers.19.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
482 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
483 |
+
"model.layers.19.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
484 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
485 |
+
"model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
486 |
+
"model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
487 |
+
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
488 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
489 |
+
"model.layers.20.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
490 |
+
"model.layers.20.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
491 |
+
"model.layers.20.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
492 |
+
"model.layers.20.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
493 |
+
"model.layers.20.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
494 |
+
"model.layers.20.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
495 |
+
"model.layers.20.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
496 |
+
"model.layers.20.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
497 |
+
"model.layers.20.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
498 |
+
"model.layers.20.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
499 |
+
"model.layers.20.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
500 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
501 |
+
"model.layers.20.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
502 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
503 |
+
"model.layers.20.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
504 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
505 |
+
"model.layers.20.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
506 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
507 |
+
"model.layers.20.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
508 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
509 |
+
"model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
510 |
+
"model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
511 |
+
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
512 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
513 |
+
"model.layers.21.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
514 |
+
"model.layers.21.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
515 |
+
"model.layers.21.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
516 |
+
"model.layers.21.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
517 |
+
"model.layers.21.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
518 |
+
"model.layers.21.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
519 |
+
"model.layers.21.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
520 |
+
"model.layers.21.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
521 |
+
"model.layers.21.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
522 |
+
"model.layers.21.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
523 |
+
"model.layers.21.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
524 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
525 |
+
"model.layers.21.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
526 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
527 |
+
"model.layers.21.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
528 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
529 |
+
"model.layers.21.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
530 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
531 |
+
"model.layers.21.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
532 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
533 |
+
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
534 |
+
"model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
535 |
+
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
536 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
537 |
+
"model.layers.22.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
538 |
+
"model.layers.22.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
539 |
+
"model.layers.22.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
540 |
+
"model.layers.22.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
541 |
+
"model.layers.22.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
542 |
+
"model.layers.22.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
543 |
+
"model.layers.22.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
544 |
+
"model.layers.22.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
545 |
+
"model.layers.22.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
546 |
+
"model.layers.22.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
547 |
+
"model.layers.22.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
548 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
549 |
+
"model.layers.22.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
550 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
551 |
+
"model.layers.22.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
552 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
553 |
+
"model.layers.22.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
554 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
555 |
+
"model.layers.22.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
556 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
557 |
+
"model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
558 |
+
"model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
559 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
560 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
561 |
+
"model.layers.23.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
562 |
+
"model.layers.23.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
563 |
+
"model.layers.23.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
564 |
+
"model.layers.23.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
565 |
+
"model.layers.23.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
566 |
+
"model.layers.23.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
567 |
+
"model.layers.23.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
568 |
+
"model.layers.23.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
569 |
+
"model.layers.23.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
570 |
+
"model.layers.23.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
571 |
+
"model.layers.23.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
572 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
573 |
+
"model.layers.23.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
574 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
575 |
+
"model.layers.23.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
576 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
577 |
+
"model.layers.23.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
578 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
579 |
+
"model.layers.23.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
580 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
581 |
+
"model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
582 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
583 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
584 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
585 |
+
"model.layers.24.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
586 |
+
"model.layers.24.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
587 |
+
"model.layers.24.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
588 |
+
"model.layers.24.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
589 |
+
"model.layers.24.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
590 |
+
"model.layers.24.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
591 |
+
"model.layers.24.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
592 |
+
"model.layers.24.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
593 |
+
"model.layers.24.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
594 |
+
"model.layers.24.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
595 |
+
"model.layers.24.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
596 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
597 |
+
"model.layers.24.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
598 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
599 |
+
"model.layers.24.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
600 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
601 |
+
"model.layers.24.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
602 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
603 |
+
"model.layers.24.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
604 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
605 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
606 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
607 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
608 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
609 |
+
"model.layers.25.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
610 |
+
"model.layers.25.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
611 |
+
"model.layers.25.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
612 |
+
"model.layers.25.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
613 |
+
"model.layers.25.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
614 |
+
"model.layers.25.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
615 |
+
"model.layers.25.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
616 |
+
"model.layers.25.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
617 |
+
"model.layers.25.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
618 |
+
"model.layers.25.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
619 |
+
"model.layers.25.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
620 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
621 |
+
"model.layers.25.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
622 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
623 |
+
"model.layers.25.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
624 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
625 |
+
"model.layers.25.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
626 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
627 |
+
"model.layers.25.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
628 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
629 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
630 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
631 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
632 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
633 |
+
"model.layers.26.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
634 |
+
"model.layers.26.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
635 |
+
"model.layers.26.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
636 |
+
"model.layers.26.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
637 |
+
"model.layers.26.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
638 |
+
"model.layers.26.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
639 |
+
"model.layers.26.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
640 |
+
"model.layers.26.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
641 |
+
"model.layers.26.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
642 |
+
"model.layers.26.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
643 |
+
"model.layers.26.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
644 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
645 |
+
"model.layers.26.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
646 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
647 |
+
"model.layers.26.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
648 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
649 |
+
"model.layers.26.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
650 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
651 |
+
"model.layers.26.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
652 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
653 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
654 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
655 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
656 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
657 |
+
"model.layers.27.self_attn.time_maa_x": "model-00003-of-00004.safetensors",
|
658 |
+
"model.layers.27.self_attn.time_maa_r": "model-00003-of-00004.safetensors",
|
659 |
+
"model.layers.27.self_attn.time_maa_k": "model-00003-of-00004.safetensors",
|
660 |
+
"model.layers.27.self_attn.time_maa_v": "model-00003-of-00004.safetensors",
|
661 |
+
"model.layers.27.self_attn.time_maa_w": "model-00003-of-00004.safetensors",
|
662 |
+
"model.layers.27.self_attn.time_maa_g": "model-00003-of-00004.safetensors",
|
663 |
+
"model.layers.27.self_attn.time_maa_w2": "model-00003-of-00004.safetensors",
|
664 |
+
"model.layers.27.self_attn.time_maa_w1": "model-00003-of-00004.safetensors",
|
665 |
+
"model.layers.27.self_attn.time_decay": "model-00003-of-00004.safetensors",
|
666 |
+
"model.layers.27.self_attn.time_decay_w1": "model-00003-of-00004.safetensors",
|
667 |
+
"model.layers.27.self_attn.time_decay_w2": "model-00003-of-00004.safetensors",
|
668 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
669 |
+
"model.layers.27.self_attn.q_proj.bias": "model-00003-of-00004.safetensors",
|
670 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
671 |
+
"model.layers.27.self_attn.k_proj.bias": "model-00003-of-00004.safetensors",
|
672 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
673 |
+
"model.layers.27.self_attn.v_proj.bias": "model-00003-of-00004.safetensors",
|
674 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
675 |
+
"model.layers.27.self_attn.gate.weight": "model-00003-of-00004.safetensors",
|
676 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
677 |
+
"model.layers.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
678 |
+
"model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
679 |
+
"model.norm.weight": "model-00004-of-00004.safetensors",
|
680 |
+
"lm_head.weight": "model-00004-of-00004.safetensors"
|
681 |
+
}
|
682 |
+
}
|
modeling_rwkv6qwen2.py
ADDED
@@ -0,0 +1,1336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
5 |
+
# and OPT implementations in this library. It has been modified from its
|
6 |
+
# original forms to accommodate minor architectural differences compared
|
7 |
+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
8 |
+
#
|
9 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
# you may not use this file except in compliance with the License.
|
11 |
+
# You may obtain a copy of the License at
|
12 |
+
#
|
13 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
#
|
15 |
+
# Unless required by applicable law or agreed to in writing, software
|
16 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
# See the License for the specific language governing permissions and
|
19 |
+
# limitations under the License.
|
20 |
+
"""PyTorch RWKV6Qwen2 model."""
|
21 |
+
|
22 |
+
import math
|
23 |
+
import inspect
|
24 |
+
from typing import List, Optional, Tuple, Union, Dict, Any
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.utils.checkpoint
|
28 |
+
from torch import nn
|
29 |
+
import torch.nn.functional as F
|
30 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
31 |
+
|
32 |
+
from transformers.cache_utils import Cache, StaticCache, DynamicCache
|
33 |
+
from transformers.generation import GenerationMixin
|
34 |
+
from transformers.modeling_outputs import (
|
35 |
+
BaseModelOutputWithPast,
|
36 |
+
CausalLMOutputWithPast,
|
37 |
+
QuestionAnsweringModelOutput,
|
38 |
+
SequenceClassifierOutputWithPast,
|
39 |
+
TokenClassifierOutput,
|
40 |
+
)
|
41 |
+
from transformers.modeling_utils import PreTrainedModel
|
42 |
+
from transformers.utils import (
|
43 |
+
add_code_sample_docstrings,
|
44 |
+
add_start_docstrings,
|
45 |
+
add_start_docstrings_to_model_forward,
|
46 |
+
is_flash_attn_2_available,
|
47 |
+
is_flash_attn_greater_or_equal_2_10,
|
48 |
+
logging,
|
49 |
+
replace_return_docstrings,
|
50 |
+
)
|
51 |
+
from .configuration_rwkv6qwen2 import RWKV6Qwen2Config
|
52 |
+
|
53 |
+
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm, repeat_kv
|
54 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
55 |
+
|
56 |
+
logger = logging.get_logger(__name__)
|
57 |
+
|
58 |
+
|
59 |
+
_CHECKPOINT_FOR_DOC = "RWKV/RWKV6Qwen2-7B"
|
60 |
+
_CONFIG_FOR_DOC = "RWKV6Qwen2Config"
|
61 |
+
|
62 |
+
class RWKV6State(Cache):
|
63 |
+
def __init__(self) -> None:
|
64 |
+
super().__init__()
|
65 |
+
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
66 |
+
self.layer_kv_states: List[torch.Tensor] = []
|
67 |
+
self.layer_shift_states: List[torch.Tensor] = []
|
68 |
+
|
69 |
+
def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
70 |
+
"""
|
71 |
+
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
72 |
+
sequence length.
|
73 |
+
"""
|
74 |
+
if layer_idx < len(self):
|
75 |
+
return (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
|
76 |
+
else:
|
77 |
+
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
78 |
+
|
79 |
+
def __iter__(self):
|
80 |
+
"""
|
81 |
+
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
|
82 |
+
keys and values
|
83 |
+
"""
|
84 |
+
for layer_idx in range(len(self)):
|
85 |
+
yield (self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx])
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
"""
|
89 |
+
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
90 |
+
to the number of layers in the model.
|
91 |
+
"""
|
92 |
+
return len(self.layer_kv_states)
|
93 |
+
|
94 |
+
def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
|
95 |
+
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
96 |
+
# Linear Attention variants do not have a maximum length
|
97 |
+
return new_seq_length
|
98 |
+
|
99 |
+
def reorder_cache(self, beam_idx: torch.LongTensor):
|
100 |
+
"""Reorders the cache for beam search, given the selected beam indices."""
|
101 |
+
raise NotImplementedError('Cannot reorder Linear Attention state')
|
102 |
+
|
103 |
+
def get_seq_length(self, layer_idx: int = 0) -> int:
|
104 |
+
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
105 |
+
return self._seen_tokens
|
106 |
+
|
107 |
+
def get_max_cache_shape(self) -> Optional[int]:
|
108 |
+
"""Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
|
109 |
+
return None
|
110 |
+
|
111 |
+
def get_max_length(self) -> Optional[int]:
|
112 |
+
"""
|
113 |
+
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
|
114 |
+
"""
|
115 |
+
return None
|
116 |
+
|
117 |
+
# def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
118 |
+
# """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
|
119 |
+
# backward compatibility."""
|
120 |
+
# legacy_cache = ()
|
121 |
+
# for layer_idx in range(len(self)):
|
122 |
+
# legacy_cache += ((self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]),)
|
123 |
+
# return legacy_cache
|
124 |
+
|
125 |
+
# @classmethod
|
126 |
+
# #@deprecate_kwarg("num_hidden_layers", version="4.47.0")
|
127 |
+
# def from_legacy_cache(
|
128 |
+
# cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, torch.FloatTensor]]] = None, num_hidden_layers: int | None = None
|
129 |
+
# ) -> "RWKV6State":
|
130 |
+
# """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
|
131 |
+
# backward compatibility."""
|
132 |
+
# cache = cls()
|
133 |
+
# if past_key_values is not None:
|
134 |
+
# for layer_idx in range(len(past_key_values)):
|
135 |
+
# layer_kv_state, layer_shift_state = past_key_values[layer_idx]
|
136 |
+
# cache.update(layer_kv_state, layer_shift_state, layer_idx)
|
137 |
+
# return cache
|
138 |
+
|
139 |
+
def crop(self, max_length: int):
|
140 |
+
# can't implement this for linear attention variants
|
141 |
+
return
|
142 |
+
|
143 |
+
@torch.no_grad
|
144 |
+
def update(
|
145 |
+
self,
|
146 |
+
kv_state: torch.Tensor,
|
147 |
+
shift_state: torch.Tensor,
|
148 |
+
token_count: int,
|
149 |
+
layer_idx: int,
|
150 |
+
cache_kwargs: Optional[Dict[str, Any]] = None,
|
151 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
152 |
+
# Update the number of seen tokens
|
153 |
+
if layer_idx == 0:
|
154 |
+
self._seen_tokens += token_count
|
155 |
+
|
156 |
+
# Update the cache
|
157 |
+
# There may be skipped layers, fill them with empty lists
|
158 |
+
for _ in range(len(self.layer_kv_states), layer_idx + 1):
|
159 |
+
self.layer_kv_states.append(torch.zeros_like(kv_state).requires_grad_(False))
|
160 |
+
self.layer_shift_states.append(torch.zeros_like(shift_state).requires_grad_(False))
|
161 |
+
self.layer_kv_states[layer_idx].copy_(kv_state)
|
162 |
+
self.layer_shift_states[layer_idx].copy_(shift_state)
|
163 |
+
|
164 |
+
return self.layer_kv_states[layer_idx], self.layer_shift_states[layer_idx]
|
165 |
+
|
166 |
+
# @deprecate_kwarg("num_hidden_layers", version="4.47.0")
|
167 |
+
# def batch_split(
|
168 |
+
# self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
|
169 |
+
# ) -> List["DynamicCache"]:
|
170 |
+
# """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
|
171 |
+
# `_split_model_inputs()` in `generation.utils`"""
|
172 |
+
# out = []
|
173 |
+
# for i in range(0, full_batch_size, split_size):
|
174 |
+
# current_split = DynamicCache()
|
175 |
+
# current_split._seen_tokens = self._seen_tokens
|
176 |
+
# current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
|
177 |
+
# current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
|
178 |
+
# out.append(current_split)
|
179 |
+
# return out
|
180 |
+
|
181 |
+
# @classmethod
|
182 |
+
# @deprecate_kwarg("num_hidden_layers", version="4.47.0")
|
183 |
+
# def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
|
184 |
+
# """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
|
185 |
+
# `generation.utils`"""
|
186 |
+
# cache = cls()
|
187 |
+
# for idx in range(len(splits[0])):
|
188 |
+
# key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
189 |
+
# value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
190 |
+
# if key_cache != []:
|
191 |
+
# layer_keys = torch.cat(key_cache, dim=0)
|
192 |
+
# layer_values = torch.cat(value_cache, dim=0)
|
193 |
+
# cache.update(layer_keys, layer_values, idx)
|
194 |
+
# return cache
|
195 |
+
|
196 |
+
# def batch_repeat_interleave(self, repeats: int):
|
197 |
+
# """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
|
198 |
+
# for layer_idx in range(len(self)):
|
199 |
+
# self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
200 |
+
# self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
201 |
+
|
202 |
+
# def batch_select_indices(self, indices: torch.Tensor):
|
203 |
+
# """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
|
204 |
+
# for layer_idx in range(len(self)):
|
205 |
+
# self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
206 |
+
# self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
207 |
+
|
208 |
+
try:
|
209 |
+
#from fla.ops.gla.chunk import chunk_gla
|
210 |
+
from fla.ops.gla.fused_recurrent import fused_recurrent_gla
|
211 |
+
except ImportError:
|
212 |
+
print("Required module is not installed. Please install it using the following commands:")
|
213 |
+
print("pip install -U git+https://github.com/fla-org/flash-linear-attention")
|
214 |
+
print("Additionally, ensure you have at least version 2.2.0 of Triton installed:")
|
215 |
+
print("pip install triton>=2.2.0")
|
216 |
+
|
217 |
+
class Qwen2RotaryEmbedding(nn.Module):
|
218 |
+
def __init__(self, config: RWKV6Qwen2Config, device=None):
|
219 |
+
super().__init__()
|
220 |
+
# BC: "rope_type" was originally "type"
|
221 |
+
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
222 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
223 |
+
else:
|
224 |
+
self.rope_type = "default"
|
225 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
226 |
+
self.original_max_seq_len = config.max_position_embeddings
|
227 |
+
|
228 |
+
self.config = config
|
229 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
230 |
+
|
231 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
232 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
233 |
+
self.original_inv_freq = self.inv_freq
|
234 |
+
|
235 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
236 |
+
"""
|
237 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
238 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
239 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
240 |
+
"""
|
241 |
+
seq_len = torch.max(position_ids) + 1
|
242 |
+
if seq_len > self.max_seq_len_cached: # growth
|
243 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
|
244 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
245 |
+
self.max_seq_len_cached = seq_len
|
246 |
+
|
247 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
248 |
+
# This .to() is needed if the model has been moved to a device after being initialized (because
|
249 |
+
# the buffer is automatically moved, but not the original copy)
|
250 |
+
self.original_inv_freq = self.original_inv_freq.to(device)
|
251 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
252 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
253 |
+
|
254 |
+
@torch.no_grad()
|
255 |
+
def forward(self, x, position_ids):
|
256 |
+
if "dynamic" in self.rope_type:
|
257 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
258 |
+
|
259 |
+
# Core RoPE block
|
260 |
+
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
261 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
262 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
263 |
+
device_type = x.device.type
|
264 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
265 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
266 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
267 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
268 |
+
cos = emb.cos()
|
269 |
+
sin = emb.sin()
|
270 |
+
|
271 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
272 |
+
cos = cos * self.attention_scaling
|
273 |
+
sin = sin * self.attention_scaling
|
274 |
+
|
275 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
276 |
+
|
277 |
+
def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
|
278 |
+
#inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
|
279 |
+
|
280 |
+
angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
|
281 |
+
angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
|
282 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
283 |
+
emb = torch.cat((angles, angles), dim=-1)
|
284 |
+
return torch.stack([emb.cos(), emb.sin()], dim=0)
|
285 |
+
#return torch.polar(torch.ones_like(angles), angles)
|
286 |
+
|
287 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
288 |
+
def rotate_half(x):
|
289 |
+
"""Rotates half the hidden dims of the input."""
|
290 |
+
x1 = x[..., : x.shape[-1] // 2]
|
291 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
292 |
+
return torch.cat((-x2, x1), dim=-1)
|
293 |
+
|
294 |
+
# # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
295 |
+
# def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim:int=1):
|
296 |
+
# B, L = q.size(0), q.size(-2)
|
297 |
+
# cos = cos[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
|
298 |
+
# sin = sin[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
|
299 |
+
# q_embed = (q * cos) + (rotate_half(q) * sin)
|
300 |
+
# k_embed = (k * cos) + (rotate_half(k) * sin)
|
301 |
+
# return q_embed, k_embed
|
302 |
+
|
303 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
304 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
q (`torch.Tensor`): The query tensor.
|
308 |
+
k (`torch.Tensor`): The key tensor.
|
309 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
310 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
311 |
+
position_ids (`torch.Tensor`, *optional*):
|
312 |
+
Deprecated and unused.
|
313 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
314 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
315 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
316 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
317 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
318 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
319 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
320 |
+
Returns:
|
321 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
322 |
+
"""
|
323 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
324 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
325 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
326 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
327 |
+
return q_embed, k_embed
|
328 |
+
|
329 |
+
def ortho_init(x, scale):
|
330 |
+
with torch.no_grad():
|
331 |
+
shape = x.shape
|
332 |
+
if len(shape) == 2:
|
333 |
+
gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
|
334 |
+
#nn.init.orthogonal_(x, gain=gain * scale)
|
335 |
+
x.copy_(nn.init.orthogonal_(torch.empty_like(x, dtype=torch.float32), gain=gain * scale))
|
336 |
+
elif len(shape) == 3:
|
337 |
+
gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
|
338 |
+
for i in range(shape[0]):
|
339 |
+
#nn.init.orthogonal_(x[i], gain=gain * scale)
|
340 |
+
x[i].copy_(nn.init.orthogonal_(torch.empty_like(x[i], dtype=torch.float32), gain=gain * scale))
|
341 |
+
else:
|
342 |
+
assert False
|
343 |
+
return x
|
344 |
+
|
345 |
+
class RWKV6Attention(nn.Module):
|
346 |
+
def __init__(self, config, layer_idx: Optional[int] = None):
|
347 |
+
super().__init__()
|
348 |
+
self.config = config
|
349 |
+
self.layer_idx = layer_idx
|
350 |
+
|
351 |
+
if layer_idx is None:
|
352 |
+
logger.warning_once(
|
353 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
354 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
355 |
+
"when creating this class."
|
356 |
+
)
|
357 |
+
|
358 |
+
self.hidden_size = config.hidden_size
|
359 |
+
self.num_heads = config.num_attention_heads
|
360 |
+
self.head_dim = getattr(config, 'head_dim', self.hidden_size // self.num_heads)
|
361 |
+
self.num_key_value_heads = config.num_key_value_heads
|
362 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
363 |
+
self.attention_dropout = config.attention_dropout
|
364 |
+
|
365 |
+
n_layer = self.config.num_hidden_layers
|
366 |
+
n_embd = self.hidden_size
|
367 |
+
dim_att = self.num_heads * self.head_dim
|
368 |
+
layer_id = self.layer_idx
|
369 |
+
|
370 |
+
if self.hidden_size % self.num_heads != 0:
|
371 |
+
raise ValueError(
|
372 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
373 |
+
f" and `num_heads`: {self.num_heads})."
|
374 |
+
)
|
375 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
376 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
377 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
378 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=getattr(config, 'attention_output_bias', config.attention_bias))
|
379 |
+
|
380 |
+
calc_lora_rank = lambda exponent, multiplier: max(1, round(self.hidden_size ** exponent * multiplier / 32)) * 32
|
381 |
+
|
382 |
+
if config.gate_rank_type == 1:
|
383 |
+
self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
384 |
+
elif config.gate_rank_type == 2:
|
385 |
+
lora_rank_gate = config.lora_rank_gate or calc_lora_rank(0.8, 0.6)
|
386 |
+
self.g1 = nn.Parameter(torch.empty(n_embd, lora_rank_gate))
|
387 |
+
self.g2 = nn.Parameter(torch.empty(lora_rank_gate, n_embd))
|
388 |
+
|
389 |
+
if config.groupnorm_att:
|
390 |
+
self.ln_x = nn.GroupNorm(self.num_heads, dim_att, eps=self.head_dim * 1e-5)
|
391 |
+
|
392 |
+
with torch.no_grad():
|
393 |
+
if config.gate_rank_type == 1:
|
394 |
+
self.gate.weight.zero_()
|
395 |
+
elif config.gate_rank_type == 2:
|
396 |
+
self.g1.zero_()
|
397 |
+
ortho_init(self.g2, 0.1)
|
398 |
+
|
399 |
+
ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
|
400 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
|
401 |
+
|
402 |
+
if self.config.use_tokenshift:
|
403 |
+
ddd = torch.ones(1, 1, n_embd)
|
404 |
+
for i in range(n_embd):
|
405 |
+
ddd[0, 0, i] = i / n_embd
|
406 |
+
|
407 |
+
ddd = torch.zeros(1, 1, n_embd)
|
408 |
+
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
409 |
+
self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
|
410 |
+
self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
|
411 |
+
self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
|
412 |
+
self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
|
413 |
+
self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
|
414 |
+
|
415 |
+
lora_rank_tokenshift = config.lora_rank_tokenshift or (32 if n_embd < 4096 else 64)
|
416 |
+
|
417 |
+
self.time_maa_w2 = nn.Parameter(torch.zeros(5, lora_rank_tokenshift, n_embd).uniform_(-0.01, 0.01))
|
418 |
+
self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_tokenshift*self.time_maa_w2.size(0)))
|
419 |
+
|
420 |
+
lora_rank_decay = config.lora_rank_decay or (64 if n_embd < 4096 else 128)
|
421 |
+
|
422 |
+
# RWKV-6
|
423 |
+
decay_speed = torch.ones(dim_att)
|
424 |
+
for n in range(dim_att):
|
425 |
+
decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
426 |
+
self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
|
427 |
+
self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, lora_rank_decay))
|
428 |
+
self.time_decay_w2 = nn.Parameter(torch.zeros(lora_rank_decay, dim_att).uniform_(-0.01, 0.01))
|
429 |
+
|
430 |
+
def forward(
|
431 |
+
self,
|
432 |
+
hidden_states: torch.Tensor,
|
433 |
+
attention_mask: Optional[torch.Tensor] = None,
|
434 |
+
position_ids: Optional[torch.LongTensor] = None,
|
435 |
+
past_key_values: Optional[RWKV6State] = None,
|
436 |
+
output_attentions: bool = False,
|
437 |
+
use_cache: bool = False,
|
438 |
+
cache_position: Optional[torch.LongTensor] = None,
|
439 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
440 |
+
):
|
441 |
+
output_shift_state = hidden_states[:, -1:].detach().clone()
|
442 |
+
|
443 |
+
bsz, q_len, hidden_dim = hidden_states.size()
|
444 |
+
H = self.num_heads
|
445 |
+
|
446 |
+
x = hidden_states
|
447 |
+
|
448 |
+
if use_cache and past_key_values is not None and len(past_key_values) > self.layer_idx:
|
449 |
+
input_kv_state, input_shift_state = past_key_values[self.layer_idx]
|
450 |
+
xprev = torch.cat([input_shift_state, x[:, :-1]], dim=1)
|
451 |
+
else:
|
452 |
+
input_kv_state = None
|
453 |
+
xprev = F.pad(x, (0, 0, 1, -1))
|
454 |
+
|
455 |
+
if self.config.use_tokenshift:
|
456 |
+
dxprev = xprev - x
|
457 |
+
|
458 |
+
xxx = x + dxprev * self.time_maa_x
|
459 |
+
xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
|
460 |
+
xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)
|
461 |
+
|
462 |
+
mr, mk, mv, mw, mg = xxx.unbind(dim=0)
|
463 |
+
xr = x + dxprev * (self.time_maa_r + mr)
|
464 |
+
xk = x + dxprev * (self.time_maa_k + mk)
|
465 |
+
xv = x + dxprev * (self.time_maa_v + mv)
|
466 |
+
xw = x + dxprev * (self.time_maa_w + mw)
|
467 |
+
xg = x + dxprev * (self.time_maa_g + mg)
|
468 |
+
else:
|
469 |
+
xr = xk = xv = xw = xg = x
|
470 |
+
|
471 |
+
query_states = self.q_proj(xr)
|
472 |
+
key_states = self.k_proj(xk)
|
473 |
+
value_states = self.v_proj(xv)
|
474 |
+
decay_states = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(query_states.dtype)
|
475 |
+
if self.config.gate_rank_type == 1:
|
476 |
+
gate_states = torch.sigmoid(self.gate(xg))
|
477 |
+
elif self.config.gate_rank_type == 2:
|
478 |
+
gate_states = torch.sigmoid(xg @ self.g1) @ self.g2
|
479 |
+
|
480 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
481 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
482 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
483 |
+
decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
484 |
+
|
485 |
+
if position_embeddings is not None:
|
486 |
+
cos, sin = position_embeddings
|
487 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
|
488 |
+
|
489 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
490 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
491 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
492 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
493 |
+
|
494 |
+
decay_states_log = -decay_states.float().exp()
|
495 |
+
decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
|
496 |
+
if self.config.balance_state:
|
497 |
+
key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)
|
498 |
+
|
499 |
+
# dealing with left-padding
|
500 |
+
if attention_mask is not None:
|
501 |
+
value_states = value_states * attention_mask[:, None, -value_states.shape[-2]:, None]
|
502 |
+
|
503 |
+
query_states = query_states.to(value_states.dtype)
|
504 |
+
key_states = key_states.to(value_states.dtype)
|
505 |
+
|
506 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
507 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
508 |
+
# cast them back in float16 just to be sure everything works as expected.
|
509 |
+
input_dtype = query_states.dtype
|
510 |
+
if input_dtype == torch.float32:
|
511 |
+
if torch.is_autocast_enabled():
|
512 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
513 |
+
# Handle the case where the model is quantized
|
514 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
515 |
+
target_dtype = self.config._pre_quantization_dtype
|
516 |
+
else:
|
517 |
+
target_dtype = self.q_proj.weight.dtype
|
518 |
+
|
519 |
+
logger.warning_once(
|
520 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
521 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
522 |
+
f" {target_dtype}."
|
523 |
+
)
|
524 |
+
|
525 |
+
query_states = query_states.to(target_dtype)
|
526 |
+
key_states = key_states.to(target_dtype)
|
527 |
+
value_states = value_states.to(target_dtype)
|
528 |
+
|
529 |
+
attn_weights = torch.empty(0, device=x.device)
|
530 |
+
|
531 |
+
scale = query_states.shape[-1] ** -0.5
|
532 |
+
output_final_state = not self.training and use_cache and past_key_values is not None
|
533 |
+
#attn_output, output_kv_state = ChunkGLAFunction.apply(query_states, key_states, value_states, decay_states_log.float(), scale, input_kv_state, output_final_state)
|
534 |
+
#attn_output, output_kv_state = chunk_gla(query_states, key_states, value_states, decay_states_log, scale, input_kv_state, output_final_state)
|
535 |
+
attn_output, output_kv_state = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log, None, scale, input_kv_state, output_final_state)
|
536 |
+
|
537 |
+
if output_final_state:
|
538 |
+
past_key_values.update(output_kv_state, output_shift_state, q_len, self.layer_idx)
|
539 |
+
|
540 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
541 |
+
attn_output = attn_output.view(bsz, q_len, -1)
|
542 |
+
if self.config.groupnorm_att:
|
543 |
+
attn_output = self.ln_x(attn_output.view(bsz * q_len, -1)).view(bsz, q_len, -1)
|
544 |
+
if self.config.gate_rank_type != 0:
|
545 |
+
attn_output = attn_output * gate_states
|
546 |
+
attn_output = self.o_proj(attn_output)
|
547 |
+
|
548 |
+
return attn_output, attn_weights
|
549 |
+
|
550 |
+
class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
|
551 |
+
def __init__(self, config: RWKV6Qwen2Config, layer_idx: int):
|
552 |
+
nn.Module.__init__(self)
|
553 |
+
self.hidden_size = config.hidden_size
|
554 |
+
|
555 |
+
self.self_attn = RWKV6Attention(config, layer_idx) #QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
|
556 |
+
|
557 |
+
self.mlp = Qwen2MLP(config)
|
558 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
559 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
560 |
+
|
561 |
+
def forward(
|
562 |
+
self,
|
563 |
+
hidden_states: torch.Tensor,
|
564 |
+
attention_mask: Optional[torch.Tensor] = None,
|
565 |
+
position_ids: Optional[torch.LongTensor] = None,
|
566 |
+
past_key_values: Optional[Cache] = None,
|
567 |
+
output_attentions: Optional[bool] = False,
|
568 |
+
use_cache: Optional[bool] = False,
|
569 |
+
cache_position: Optional[torch.LongTensor] = None,
|
570 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
571 |
+
**kwargs,
|
572 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
573 |
+
residual = hidden_states
|
574 |
+
|
575 |
+
hidden_states = self.input_layernorm(hidden_states)
|
576 |
+
|
577 |
+
# Self Attention
|
578 |
+
hidden_states, self_attn_weights = self.self_attn(
|
579 |
+
hidden_states=hidden_states,
|
580 |
+
attention_mask=attention_mask,
|
581 |
+
position_ids=position_ids,
|
582 |
+
past_key_values=past_key_values,
|
583 |
+
output_attentions=output_attentions,
|
584 |
+
use_cache=use_cache,
|
585 |
+
cache_position=cache_position,
|
586 |
+
position_embeddings=position_embeddings,
|
587 |
+
**kwargs,
|
588 |
+
)
|
589 |
+
hidden_states = residual + hidden_states
|
590 |
+
|
591 |
+
# Fully Connected
|
592 |
+
residual = hidden_states
|
593 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
594 |
+
hidden_states = self.mlp(hidden_states)
|
595 |
+
hidden_states = residual + hidden_states
|
596 |
+
|
597 |
+
outputs = (hidden_states,)
|
598 |
+
if output_attentions:
|
599 |
+
outputs += (self_attn_weights,)
|
600 |
+
|
601 |
+
return outputs
|
602 |
+
|
603 |
+
RWKV6QWEN2_START_DOCSTRING = r"""
|
604 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
605 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
606 |
+
etc.)
|
607 |
+
|
608 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
609 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
610 |
+
and behavior.
|
611 |
+
|
612 |
+
Parameters:
|
613 |
+
config ([`RWKV6Qwen2Config`]):
|
614 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
615 |
+
load the weights associated with the model, only the configuration. Check out the
|
616 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
617 |
+
"""
|
618 |
+
|
619 |
+
|
620 |
+
@add_start_docstrings(
|
621 |
+
"The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
|
622 |
+
RWKV6QWEN2_START_DOCSTRING,
|
623 |
+
)
|
624 |
+
class RWKV6Qwen2PreTrainedModel(PreTrainedModel):
|
625 |
+
config_class = RWKV6Qwen2Config
|
626 |
+
base_model_prefix = "model"
|
627 |
+
supports_gradient_checkpointing = True
|
628 |
+
_no_split_modules = ["RWKV6Qwen2DecoderLayer"]
|
629 |
+
_skip_keys_device_placement = "past_key_values"
|
630 |
+
_supports_flash_attn_2 = True
|
631 |
+
_supports_sdpa = True
|
632 |
+
_supports_cache_class = True
|
633 |
+
_supports_quantized_cache = True
|
634 |
+
_supports_static_cache = True
|
635 |
+
|
636 |
+
def _init_weights(self, module):
|
637 |
+
std = self.config.initializer_range
|
638 |
+
if isinstance(module, nn.Linear):
|
639 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
640 |
+
if module.bias is not None:
|
641 |
+
module.bias.data.zero_()
|
642 |
+
elif isinstance(module, nn.Embedding):
|
643 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
644 |
+
if module.padding_idx is not None:
|
645 |
+
module.weight.data[module.padding_idx].zero_()
|
646 |
+
|
647 |
+
|
648 |
+
RWKV6QWEN2_INPUTS_DOCSTRING = r"""
|
649 |
+
Args:
|
650 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
651 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
652 |
+
it.
|
653 |
+
|
654 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
655 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
656 |
+
|
657 |
+
[What are input IDs?](../glossary#input-ids)
|
658 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
659 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
660 |
+
|
661 |
+
- 1 for tokens that are **not masked**,
|
662 |
+
- 0 for tokens that are **masked**.
|
663 |
+
|
664 |
+
[What are attention masks?](../glossary#attention-mask)
|
665 |
+
|
666 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
667 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
668 |
+
|
669 |
+
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
670 |
+
`past_key_values`).
|
671 |
+
|
672 |
+
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
673 |
+
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
674 |
+
information on the default strategy.
|
675 |
+
|
676 |
+
- 1 indicates the head is **not masked**,
|
677 |
+
- 0 indicates the head is **masked**.
|
678 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
679 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
680 |
+
config.n_positions - 1]`.
|
681 |
+
|
682 |
+
[What are position IDs?](../glossary#position-ids)
|
683 |
+
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
684 |
+
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
685 |
+
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
686 |
+
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
687 |
+
|
688 |
+
Two formats are allowed:
|
689 |
+
- a [`~cache_utils.Cache`] instance, see our
|
690 |
+
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
691 |
+
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
692 |
+
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
693 |
+
cache format.
|
694 |
+
|
695 |
+
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
696 |
+
legacy cache format will be returned.
|
697 |
+
|
698 |
+
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
699 |
+
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
700 |
+
of shape `(batch_size, sequence_length)`.
|
701 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
702 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
703 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
704 |
+
model's internal embedding lookup matrix.
|
705 |
+
use_cache (`bool`, *optional*):
|
706 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
707 |
+
`past_key_values`).
|
708 |
+
output_attentions (`bool`, *optional*):
|
709 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
710 |
+
tensors for more detail.
|
711 |
+
output_hidden_states (`bool`, *optional*):
|
712 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
713 |
+
more detail.
|
714 |
+
return_dict (`bool`, *optional*):
|
715 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
716 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
717 |
+
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
718 |
+
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
719 |
+
the complete sequence length.
|
720 |
+
"""
|
721 |
+
|
722 |
+
@add_start_docstrings(
|
723 |
+
"The bare RWKV6Qwen2 Model outputting raw hidden-states without any specific head on top.",
|
724 |
+
RWKV6QWEN2_START_DOCSTRING,
|
725 |
+
)
|
726 |
+
class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
|
727 |
+
"""
|
728 |
+
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
|
729 |
+
|
730 |
+
Args:
|
731 |
+
config: RWKV6Qwen2Config
|
732 |
+
"""
|
733 |
+
|
734 |
+
def __init__(self, config: RWKV6Qwen2Config):
|
735 |
+
super().__init__(config)
|
736 |
+
self.padding_idx = config.pad_token_id
|
737 |
+
self.vocab_size = config.vocab_size
|
738 |
+
|
739 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
740 |
+
self.layers = nn.ModuleList(
|
741 |
+
[RWKV6Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
742 |
+
)
|
743 |
+
self._attn_implementation = config._attn_implementation
|
744 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
745 |
+
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
|
746 |
+
|
747 |
+
self.gradient_checkpointing = False
|
748 |
+
# Initialize weights and apply final processing
|
749 |
+
self.post_init()
|
750 |
+
|
751 |
+
def get_input_embeddings(self):
|
752 |
+
return self.embed_tokens
|
753 |
+
|
754 |
+
def set_input_embeddings(self, value):
|
755 |
+
self.embed_tokens = value
|
756 |
+
|
757 |
+
@add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
|
758 |
+
def forward(
|
759 |
+
self,
|
760 |
+
input_ids: torch.LongTensor = None,
|
761 |
+
attention_mask: Optional[torch.Tensor] = None,
|
762 |
+
position_ids: Optional[torch.LongTensor] = None,
|
763 |
+
past_key_values: Optional[Cache] = None,
|
764 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
765 |
+
use_cache: Optional[bool] = None,
|
766 |
+
output_attentions: Optional[bool] = None,
|
767 |
+
output_hidden_states: Optional[bool] = None,
|
768 |
+
return_dict: Optional[bool] = None,
|
769 |
+
cache_position: Optional[torch.LongTensor] = None,
|
770 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
771 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
772 |
+
output_hidden_states = (
|
773 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
774 |
+
)
|
775 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
776 |
+
|
777 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
778 |
+
|
779 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
780 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
781 |
+
|
782 |
+
if self.gradient_checkpointing and self.training:
|
783 |
+
if use_cache:
|
784 |
+
logger.warning_once(
|
785 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
786 |
+
)
|
787 |
+
use_cache = False
|
788 |
+
|
789 |
+
# kept for BC (non `Cache` `past_key_values` inputs)
|
790 |
+
#return_legacy_cache = False
|
791 |
+
if use_cache and not isinstance(past_key_values, RWKV6State):
|
792 |
+
#return_legacy_cache = True
|
793 |
+
past_key_values = RWKV6State()
|
794 |
+
# if past_key_values is None:
|
795 |
+
# past_key_values = DynamicCache()
|
796 |
+
# else:
|
797 |
+
# past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
798 |
+
# logger.warning_once(
|
799 |
+
# "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
800 |
+
# "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
801 |
+
# "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
802 |
+
# )
|
803 |
+
|
804 |
+
if inputs_embeds is None:
|
805 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
806 |
+
|
807 |
+
if cache_position is None:
|
808 |
+
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
809 |
+
cache_position = torch.arange(
|
810 |
+
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
811 |
+
)
|
812 |
+
|
813 |
+
if position_ids is None:
|
814 |
+
position_ids = cache_position.unsqueeze(0)
|
815 |
+
|
816 |
+
# causal_mask = self._update_causal_mask(
|
817 |
+
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
818 |
+
# )
|
819 |
+
|
820 |
+
causal_mask = None
|
821 |
+
|
822 |
+
hidden_states = inputs_embeds
|
823 |
+
|
824 |
+
# create position embeddings to be shared across the decoder layers
|
825 |
+
position_embeddings = None
|
826 |
+
if self.config.use_rope:
|
827 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
828 |
+
|
829 |
+
# decoder layers
|
830 |
+
all_hidden_states = () if output_hidden_states else None
|
831 |
+
all_self_attns = () if output_attentions else None
|
832 |
+
next_decoder_cache = None
|
833 |
+
|
834 |
+
for decoder_layer in self.layers:
|
835 |
+
if output_hidden_states:
|
836 |
+
all_hidden_states += (hidden_states,)
|
837 |
+
|
838 |
+
if self.gradient_checkpointing and self.training:
|
839 |
+
layer_outputs = self._gradient_checkpointing_func(
|
840 |
+
decoder_layer.__call__,
|
841 |
+
hidden_states,
|
842 |
+
causal_mask,
|
843 |
+
position_ids,
|
844 |
+
past_key_values,
|
845 |
+
output_attentions,
|
846 |
+
use_cache,
|
847 |
+
cache_position,
|
848 |
+
position_embeddings,
|
849 |
+
)
|
850 |
+
else:
|
851 |
+
layer_outputs = decoder_layer(
|
852 |
+
hidden_states,
|
853 |
+
attention_mask=attention_mask,
|
854 |
+
position_ids=position_ids,
|
855 |
+
past_key_values=past_key_values,
|
856 |
+
output_attentions=output_attentions,
|
857 |
+
use_cache=use_cache,
|
858 |
+
cache_position=cache_position,
|
859 |
+
position_embeddings=position_embeddings,
|
860 |
+
)
|
861 |
+
|
862 |
+
hidden_states = layer_outputs[0]
|
863 |
+
|
864 |
+
if output_attentions:
|
865 |
+
all_self_attns += (layer_outputs[1],)
|
866 |
+
|
867 |
+
hidden_states = self.norm(hidden_states)
|
868 |
+
|
869 |
+
# add hidden states from the last decoder layer
|
870 |
+
if output_hidden_states:
|
871 |
+
all_hidden_states += (hidden_states,)
|
872 |
+
|
873 |
+
#if return_legacy_cache:
|
874 |
+
# next_cache = next_cache.to_legacy_cache()
|
875 |
+
|
876 |
+
if not return_dict:
|
877 |
+
return tuple(v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None)
|
878 |
+
return BaseModelOutputWithPast(
|
879 |
+
last_hidden_state=hidden_states,
|
880 |
+
past_key_values=past_key_values,
|
881 |
+
hidden_states=all_hidden_states,
|
882 |
+
attentions=all_self_attns,
|
883 |
+
)
|
884 |
+
|
885 |
+
class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
|
886 |
+
_tied_weights_keys = ["lm_head.weight"]
|
887 |
+
|
888 |
+
def __init__(self, config):
|
889 |
+
super().__init__(config)
|
890 |
+
self.model = RWKV6Qwen2Model(config)
|
891 |
+
self.vocab_size = config.vocab_size
|
892 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
893 |
+
|
894 |
+
# Initialize weights and apply final processing
|
895 |
+
self.post_init()
|
896 |
+
|
897 |
+
def get_input_embeddings(self):
|
898 |
+
return self.model.embed_tokens
|
899 |
+
|
900 |
+
def set_input_embeddings(self, value):
|
901 |
+
self.model.embed_tokens = value
|
902 |
+
|
903 |
+
def get_output_embeddings(self):
|
904 |
+
return self.lm_head
|
905 |
+
|
906 |
+
def set_output_embeddings(self, new_embeddings):
|
907 |
+
self.lm_head = new_embeddings
|
908 |
+
|
909 |
+
def set_decoder(self, decoder):
|
910 |
+
self.model = decoder
|
911 |
+
|
912 |
+
def get_decoder(self):
|
913 |
+
return self.model
|
914 |
+
|
915 |
+
@add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
|
916 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
917 |
+
def forward(
|
918 |
+
self,
|
919 |
+
input_ids: torch.LongTensor = None,
|
920 |
+
attention_mask: Optional[torch.Tensor] = None,
|
921 |
+
position_ids: Optional[torch.LongTensor] = None,
|
922 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
923 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
924 |
+
labels: Optional[torch.LongTensor] = None,
|
925 |
+
use_cache: Optional[bool] = None,
|
926 |
+
output_attentions: Optional[bool] = None,
|
927 |
+
output_hidden_states: Optional[bool] = None,
|
928 |
+
return_dict: Optional[bool] = None,
|
929 |
+
cache_position: Optional[torch.LongTensor] = None,
|
930 |
+
num_logits_to_keep: int = 0,
|
931 |
+
**loss_kwargs,
|
932 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
933 |
+
r"""
|
934 |
+
Args:
|
935 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
936 |
+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
937 |
+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
938 |
+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
939 |
+
|
940 |
+
num_logits_to_keep (`int`, *optional*):
|
941 |
+
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
|
942 |
+
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
943 |
+
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
944 |
+
|
945 |
+
Returns:
|
946 |
+
|
947 |
+
Example:
|
948 |
+
|
949 |
+
```python
|
950 |
+
>>> from transformers import AutoTokenizer, RWKV6Qwen2ForCausalLM
|
951 |
+
|
952 |
+
>>> model = RWKV6Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
953 |
+
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
954 |
+
|
955 |
+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
956 |
+
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
957 |
+
|
958 |
+
>>> # Generate
|
959 |
+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
960 |
+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
961 |
+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
962 |
+
```"""
|
963 |
+
|
964 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
965 |
+
output_hidden_states = (
|
966 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
967 |
+
)
|
968 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
969 |
+
|
970 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
971 |
+
outputs = self.model(
|
972 |
+
input_ids=input_ids,
|
973 |
+
attention_mask=attention_mask,
|
974 |
+
position_ids=position_ids,
|
975 |
+
past_key_values=past_key_values,
|
976 |
+
inputs_embeds=inputs_embeds,
|
977 |
+
use_cache=use_cache,
|
978 |
+
output_attentions=output_attentions,
|
979 |
+
output_hidden_states=output_hidden_states,
|
980 |
+
return_dict=return_dict,
|
981 |
+
cache_position=cache_position,
|
982 |
+
)
|
983 |
+
|
984 |
+
hidden_states = outputs[0]
|
985 |
+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
986 |
+
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
|
987 |
+
|
988 |
+
loss = None
|
989 |
+
if labels is not None:
|
990 |
+
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
991 |
+
|
992 |
+
if not return_dict:
|
993 |
+
output = (logits,) + outputs[1:]
|
994 |
+
return (loss,) + output if loss is not None else output
|
995 |
+
|
996 |
+
return CausalLMOutputWithPast(
|
997 |
+
loss=loss,
|
998 |
+
logits=logits,
|
999 |
+
past_key_values=outputs.past_key_values,
|
1000 |
+
hidden_states=outputs.hidden_states,
|
1001 |
+
attentions=outputs.attentions,
|
1002 |
+
)
|
1003 |
+
|
1004 |
+
def prepare_inputs_for_generation(
|
1005 |
+
self,
|
1006 |
+
input_ids: torch.LongTensor,
|
1007 |
+
past_key_values: Optional[Cache] = None,
|
1008 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
1009 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1010 |
+
cache_position: Optional[torch.LongTensor] = None,
|
1011 |
+
**kwargs,
|
1012 |
+
):
|
1013 |
+
# only last token for `inputs_ids` if the `past_key_values` is not empty.
|
1014 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
1015 |
+
input_ids = input_ids[:, -1:]
|
1016 |
+
|
1017 |
+
model_inputs = {
|
1018 |
+
'past_key_values': past_key_values,
|
1019 |
+
'attention_mask': attention_mask,
|
1020 |
+
'cache_position': cache_position,
|
1021 |
+
}
|
1022 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1023 |
+
if inputs_embeds is not None and past_key_values is None:
|
1024 |
+
model_inputs['inputs_embeds'] = inputs_embeds
|
1025 |
+
else:
|
1026 |
+
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
|
1027 |
+
# recompiles graphs as the stride of the inputs is a guard.
|
1028 |
+
# Ref: https://github.com/huggingface/transformers/pull/29114
|
1029 |
+
# TODO: use `next_tokens` directly instead.
|
1030 |
+
model_inputs['input_ids'] = input_ids.contiguous()
|
1031 |
+
|
1032 |
+
model_inputs.update(**kwargs)
|
1033 |
+
|
1034 |
+
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
|
1035 |
+
model_inputs.pop("labels", None)
|
1036 |
+
|
1037 |
+
return model_inputs
|
1038 |
+
|
1039 |
+
@add_start_docstrings(
|
1040 |
+
"""
|
1041 |
+
The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).
|
1042 |
+
|
1043 |
+
[`RWKV6Qwen2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1044 |
+
(e.g. GPT-2) do.
|
1045 |
+
|
1046 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1047 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1048 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1049 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1050 |
+
each row of the batch).
|
1051 |
+
""",
|
1052 |
+
RWKV6QWEN2_START_DOCSTRING,
|
1053 |
+
)
|
1054 |
+
class RWKV6Qwen2ForSequenceClassification(RWKV6Qwen2PreTrainedModel):
|
1055 |
+
def __init__(self, config):
|
1056 |
+
super().__init__(config)
|
1057 |
+
self.num_labels = config.num_labels
|
1058 |
+
self.model = RWKV6Qwen2Model(config)
|
1059 |
+
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
1060 |
+
|
1061 |
+
# Initialize weights and apply final processing
|
1062 |
+
self.post_init()
|
1063 |
+
|
1064 |
+
def get_input_embeddings(self):
|
1065 |
+
return self.model.embed_tokens
|
1066 |
+
|
1067 |
+
def set_input_embeddings(self, value):
|
1068 |
+
self.model.embed_tokens = value
|
1069 |
+
|
1070 |
+
@add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
|
1071 |
+
def forward(
|
1072 |
+
self,
|
1073 |
+
input_ids: torch.LongTensor = None,
|
1074 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1075 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1076 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1077 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1078 |
+
labels: Optional[torch.LongTensor] = None,
|
1079 |
+
use_cache: Optional[bool] = None,
|
1080 |
+
output_attentions: Optional[bool] = None,
|
1081 |
+
output_hidden_states: Optional[bool] = None,
|
1082 |
+
return_dict: Optional[bool] = None,
|
1083 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1084 |
+
r"""
|
1085 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1086 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1087 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1088 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1089 |
+
"""
|
1090 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1091 |
+
|
1092 |
+
transformer_outputs = self.model(
|
1093 |
+
input_ids,
|
1094 |
+
attention_mask=attention_mask,
|
1095 |
+
position_ids=position_ids,
|
1096 |
+
past_key_values=past_key_values,
|
1097 |
+
inputs_embeds=inputs_embeds,
|
1098 |
+
use_cache=use_cache,
|
1099 |
+
output_attentions=output_attentions,
|
1100 |
+
output_hidden_states=output_hidden_states,
|
1101 |
+
return_dict=return_dict,
|
1102 |
+
)
|
1103 |
+
hidden_states = transformer_outputs[0]
|
1104 |
+
logits = self.score(hidden_states)
|
1105 |
+
|
1106 |
+
if input_ids is not None:
|
1107 |
+
batch_size = input_ids.shape[0]
|
1108 |
+
else:
|
1109 |
+
batch_size = inputs_embeds.shape[0]
|
1110 |
+
|
1111 |
+
if self.config.pad_token_id is None and batch_size != 1:
|
1112 |
+
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
1113 |
+
if self.config.pad_token_id is None:
|
1114 |
+
sequence_lengths = -1
|
1115 |
+
else:
|
1116 |
+
if input_ids is not None:
|
1117 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1118 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1119 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1120 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1121 |
+
else:
|
1122 |
+
sequence_lengths = -1
|
1123 |
+
|
1124 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1125 |
+
|
1126 |
+
loss = None
|
1127 |
+
if labels is not None:
|
1128 |
+
labels = labels.to(logits.device)
|
1129 |
+
if self.config.problem_type is None:
|
1130 |
+
if self.num_labels == 1:
|
1131 |
+
self.config.problem_type = "regression"
|
1132 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1133 |
+
self.config.problem_type = "single_label_classification"
|
1134 |
+
else:
|
1135 |
+
self.config.problem_type = "multi_label_classification"
|
1136 |
+
|
1137 |
+
if self.config.problem_type == "regression":
|
1138 |
+
loss_fct = MSELoss()
|
1139 |
+
if self.num_labels == 1:
|
1140 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1141 |
+
else:
|
1142 |
+
loss = loss_fct(pooled_logits, labels)
|
1143 |
+
elif self.config.problem_type == "single_label_classification":
|
1144 |
+
loss_fct = CrossEntropyLoss()
|
1145 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1146 |
+
elif self.config.problem_type == "multi_label_classification":
|
1147 |
+
loss_fct = BCEWithLogitsLoss()
|
1148 |
+
loss = loss_fct(pooled_logits, labels)
|
1149 |
+
if not return_dict:
|
1150 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1151 |
+
return ((loss,) + output) if loss is not None else output
|
1152 |
+
|
1153 |
+
return SequenceClassifierOutputWithPast(
|
1154 |
+
loss=loss,
|
1155 |
+
logits=pooled_logits,
|
1156 |
+
past_key_values=transformer_outputs.past_key_values,
|
1157 |
+
hidden_states=transformer_outputs.hidden_states,
|
1158 |
+
attentions=transformer_outputs.attentions,
|
1159 |
+
)
|
1160 |
+
|
1161 |
+
|
1162 |
+
@add_start_docstrings(
|
1163 |
+
"""
|
1164 |
+
The RWKV6Qwen2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
|
1165 |
+
output) e.g. for Named-Entity-Recognition (NER) tasks.
|
1166 |
+
""",
|
1167 |
+
RWKV6QWEN2_START_DOCSTRING,
|
1168 |
+
)
|
1169 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->RWKV6Qwen2, LLAMA->RWKV6QWEN2
|
1170 |
+
class RWKV6Qwen2ForTokenClassification(RWKV6Qwen2PreTrainedModel):
|
1171 |
+
def __init__(self, config):
|
1172 |
+
super().__init__(config)
|
1173 |
+
self.num_labels = config.num_labels
|
1174 |
+
self.model = RWKV6Qwen2Model(config)
|
1175 |
+
if getattr(config, "classifier_dropout", None) is not None:
|
1176 |
+
classifier_dropout = config.classifier_dropout
|
1177 |
+
elif getattr(config, "hidden_dropout", None) is not None:
|
1178 |
+
classifier_dropout = config.hidden_dropout
|
1179 |
+
else:
|
1180 |
+
classifier_dropout = 0.1
|
1181 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1182 |
+
self.score = nn.Linear(config.hidden_size, config.num_labels)
|
1183 |
+
|
1184 |
+
# Initialize weights and apply final processing
|
1185 |
+
self.post_init()
|
1186 |
+
|
1187 |
+
def get_input_embeddings(self):
|
1188 |
+
return self.model.embed_tokens
|
1189 |
+
|
1190 |
+
def set_input_embeddings(self, value):
|
1191 |
+
self.model.embed_tokens = value
|
1192 |
+
|
1193 |
+
@add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
|
1194 |
+
@add_code_sample_docstrings(
|
1195 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1196 |
+
output_type=TokenClassifierOutput,
|
1197 |
+
config_class=_CONFIG_FOR_DOC,
|
1198 |
+
)
|
1199 |
+
def forward(
|
1200 |
+
self,
|
1201 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1202 |
+
attention_mask: Optional[torch.Tensor] = None,
|
1203 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1204 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
1205 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1206 |
+
labels: Optional[torch.LongTensor] = None,
|
1207 |
+
use_cache: Optional[bool] = None,
|
1208 |
+
output_attentions: Optional[bool] = None,
|
1209 |
+
output_hidden_states: Optional[bool] = None,
|
1210 |
+
return_dict: Optional[bool] = None,
|
1211 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1212 |
+
r"""
|
1213 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1214 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1215 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1216 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1217 |
+
"""
|
1218 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1219 |
+
|
1220 |
+
outputs = self.model(
|
1221 |
+
input_ids,
|
1222 |
+
attention_mask=attention_mask,
|
1223 |
+
position_ids=position_ids,
|
1224 |
+
past_key_values=past_key_values,
|
1225 |
+
inputs_embeds=inputs_embeds,
|
1226 |
+
use_cache=use_cache,
|
1227 |
+
output_attentions=output_attentions,
|
1228 |
+
output_hidden_states=output_hidden_states,
|
1229 |
+
return_dict=return_dict,
|
1230 |
+
)
|
1231 |
+
sequence_output = outputs[0]
|
1232 |
+
sequence_output = self.dropout(sequence_output)
|
1233 |
+
logits = self.score(sequence_output)
|
1234 |
+
|
1235 |
+
loss = None
|
1236 |
+
if labels is not None:
|
1237 |
+
loss = self.loss_function(logits, labels, self.config)
|
1238 |
+
|
1239 |
+
if not return_dict:
|
1240 |
+
output = (logits,) + outputs[2:]
|
1241 |
+
return ((loss,) + output) if loss is not None else output
|
1242 |
+
|
1243 |
+
return TokenClassifierOutput(
|
1244 |
+
loss=loss,
|
1245 |
+
logits=logits,
|
1246 |
+
hidden_states=outputs.hidden_states,
|
1247 |
+
attentions=outputs.attentions,
|
1248 |
+
)
|
1249 |
+
|
1250 |
+
|
1251 |
+
@add_start_docstrings(
|
1252 |
+
"""
|
1253 |
+
The RWKV6Qwen2 Model transformer with a span classification head on top for extractive question-answering tasks like
|
1254 |
+
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1255 |
+
""",
|
1256 |
+
RWKV6QWEN2_START_DOCSTRING,
|
1257 |
+
)
|
1258 |
+
# Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->RWKV6Qwen2, MISTRAL->RWKV6QWEN2
|
1259 |
+
class RWKV6Qwen2ForQuestionAnswering(RWKV6Qwen2PreTrainedModel):
|
1260 |
+
base_model_prefix = "model"
|
1261 |
+
|
1262 |
+
# Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->RWKV6Qwen2
|
1263 |
+
def __init__(self, config):
|
1264 |
+
super().__init__(config)
|
1265 |
+
self.model = RWKV6Qwen2Model(config)
|
1266 |
+
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
1267 |
+
|
1268 |
+
# Initialize weights and apply final processing
|
1269 |
+
self.post_init()
|
1270 |
+
|
1271 |
+
def get_input_embeddings(self):
|
1272 |
+
return self.model.embed_tokens
|
1273 |
+
|
1274 |
+
def set_input_embeddings(self, value):
|
1275 |
+
self.model.embed_tokens = value
|
1276 |
+
|
1277 |
+
@add_start_docstrings_to_model_forward(RWKV6QWEN2_INPUTS_DOCSTRING)
|
1278 |
+
def forward(
|
1279 |
+
self,
|
1280 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1281 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1282 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1283 |
+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
|
1284 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1285 |
+
start_positions: Optional[torch.LongTensor] = None,
|
1286 |
+
end_positions: Optional[torch.LongTensor] = None,
|
1287 |
+
output_attentions: Optional[bool] = None,
|
1288 |
+
output_hidden_states: Optional[bool] = None,
|
1289 |
+
return_dict: Optional[bool] = None,
|
1290 |
+
**kwargs,
|
1291 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1292 |
+
r"""
|
1293 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1294 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1295 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1296 |
+
are not taken into account for computing the loss.
|
1297 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1298 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1299 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1300 |
+
are not taken into account for computing the loss.
|
1301 |
+
"""
|
1302 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1303 |
+
|
1304 |
+
outputs = self.model(
|
1305 |
+
input_ids,
|
1306 |
+
attention_mask=attention_mask,
|
1307 |
+
position_ids=position_ids,
|
1308 |
+
past_key_values=past_key_values,
|
1309 |
+
inputs_embeds=inputs_embeds,
|
1310 |
+
output_attentions=output_attentions,
|
1311 |
+
output_hidden_states=output_hidden_states,
|
1312 |
+
return_dict=return_dict,
|
1313 |
+
)
|
1314 |
+
|
1315 |
+
sequence_output = outputs[0]
|
1316 |
+
|
1317 |
+
logits = self.qa_outputs(sequence_output)
|
1318 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1319 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1320 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1321 |
+
|
1322 |
+
loss = None
|
1323 |
+
if start_positions is not None and end_positions is not None:
|
1324 |
+
loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
|
1325 |
+
|
1326 |
+
if not return_dict:
|
1327 |
+
output = (start_logits, end_logits) + outputs[2:]
|
1328 |
+
return ((loss,) + output) if loss is not None else output
|
1329 |
+
|
1330 |
+
return QuestionAnsweringModelOutput(
|
1331 |
+
loss=loss,
|
1332 |
+
start_logits=start_logits,
|
1333 |
+
end_logits=end_logits,
|
1334 |
+
hidden_states=outputs.hidden_states,
|
1335 |
+
attentions=outputs.attentions,
|
1336 |
+
)
|
qwen2.py
ADDED
@@ -0,0 +1,670 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, math, gc, importlib.util
|
2 |
+
import torch
|
3 |
+
import torch.utils.checkpoint
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
from torch import Tensor
|
7 |
+
from typing import Tuple, Optional
|
8 |
+
|
9 |
+
#from src.state import ModelState, BlockState, ChannelMixState, TimeMixState, Shared
|
10 |
+
|
11 |
+
#from configs import TrainerCLI_Config, Model_Config, Transformer_Config, Train_Config
|
12 |
+
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
13 |
+
|
14 |
+
#from src.rotary import generate_rotary_embedding, generate_binary_rotary_embedding, apply_rotary_embedding
|
15 |
+
|
16 |
+
#from src.CoreDependencies import *
|
17 |
+
|
18 |
+
from dataclasses import dataclass
|
19 |
+
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
if importlib.util.find_spec('deepspeed'):
|
22 |
+
import deepspeed
|
23 |
+
|
24 |
+
from logger import print0 as print
|
25 |
+
|
26 |
+
from fla.ops.gla.chunk import chunk_gla
|
27 |
+
from fla.ops.gla.fused_recurrent import fused_recurrent_gla
|
28 |
+
|
29 |
+
class ModelState:
|
30 |
+
def __init__(self):
|
31 |
+
self.seq_pos = 0
|
32 |
+
self.input_tokens_cache = torch.tensor([])
|
33 |
+
self.k_cache = torch.tensor([])
|
34 |
+
self.block_states:list[BlockState] = []
|
35 |
+
|
36 |
+
class TimeMixState:
|
37 |
+
def __init__(self, wkv_state=torch.tensor([]), shift_state=torch.tensor([])):
|
38 |
+
self.wkv_state = wkv_state
|
39 |
+
self.shift_state = shift_state
|
40 |
+
|
41 |
+
class ChannelMixState:
|
42 |
+
def __init__(self, shift_state=torch.tensor([])):
|
43 |
+
self.shift_state = shift_state
|
44 |
+
|
45 |
+
class BlockState:
|
46 |
+
def __init__(self, time_mix_state: TimeMixState, channel_mix_state: ChannelMixState):
|
47 |
+
self.time_mix_state = time_mix_state
|
48 |
+
self.channel_mix_state = channel_mix_state
|
49 |
+
|
50 |
+
class Shared:
|
51 |
+
def __init__(self):
|
52 |
+
self.angles = torch.tensor([])
|
53 |
+
self.bias_mask = torch.tensor([])
|
54 |
+
|
55 |
+
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
|
56 |
+
class Qwen2RMSNorm(nn.Module):
|
57 |
+
def __init__(self, hidden_size, eps=1e-6):
|
58 |
+
"""
|
59 |
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
60 |
+
"""
|
61 |
+
super().__init__()
|
62 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
63 |
+
self.variance_epsilon = eps
|
64 |
+
|
65 |
+
def forward(self, hidden_states):
|
66 |
+
input_dtype = hidden_states.dtype
|
67 |
+
hidden_states = hidden_states.to(torch.float32)
|
68 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
69 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
70 |
+
return self.weight * hidden_states.to(input_dtype)
|
71 |
+
|
72 |
+
def extra_repr(self):
|
73 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
74 |
+
|
75 |
+
def generate_rotary_embedding(max_seqlen:int, dim:int, theta:float = 10000.0, scale:float = 1):
|
76 |
+
#inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float).to(device) / dim))
|
77 |
+
|
78 |
+
angular_velocity = theta ** -(torch.arange(0, dim, 2, dtype=torch.float) / dim) / scale # frequencies from 1.0 ... 1/theta
|
79 |
+
angles = torch.outer(torch.arange(max_seqlen), angular_velocity)
|
80 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
81 |
+
emb = torch.cat((angles, angles), dim=-1)
|
82 |
+
return torch.stack([emb.cos(), emb.sin()], dim=0)
|
83 |
+
#return torch.polar(torch.ones_like(angles), angles)
|
84 |
+
|
85 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
86 |
+
def rotate_half(x):
|
87 |
+
"""Rotates half the hidden dims of the input."""
|
88 |
+
x1 = x[..., : x.shape[-1] // 2]
|
89 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
90 |
+
return torch.cat((-x2, x1), dim=-1)
|
91 |
+
|
92 |
+
|
93 |
+
# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
|
94 |
+
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim:int=1):
|
95 |
+
B, L = q.size(0), q.size(-2)
|
96 |
+
cos = cos[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
|
97 |
+
sin = sin[:L].unsqueeze(0).expand(B,L,-1).unsqueeze(unsqueeze_dim)
|
98 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
99 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
100 |
+
return q_embed, k_embed
|
101 |
+
|
102 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
103 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
104 |
+
"""
|
105 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
106 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
107 |
+
"""
|
108 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
109 |
+
if n_rep == 1:
|
110 |
+
return hidden_states
|
111 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
112 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
113 |
+
|
114 |
+
def get_tmix_default_state(x:Tensor, config:Qwen2Config, requires_grad:bool):
|
115 |
+
B, T, C = x.size()
|
116 |
+
return TimeMixState(
|
117 |
+
torch.zeros([B, config.num_attention_heads, config.hidden_size // config.num_attention_heads, config.hidden_size // config.num_attention_heads], dtype=x.dtype, device=x.device, requires_grad=requires_grad),
|
118 |
+
torch.zeros([B, C], dtype=x.dtype, device=x.device, requires_grad=requires_grad)
|
119 |
+
)
|
120 |
+
|
121 |
+
@dataclass
|
122 |
+
class LLMOutput:
|
123 |
+
logits: torch.FloatTensor = None
|
124 |
+
model_state: ModelState = None
|
125 |
+
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
126 |
+
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
127 |
+
post_attention_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
128 |
+
student_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
129 |
+
student_post_attention_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
130 |
+
|
131 |
+
class TMix_qwen2(nn.Module):
|
132 |
+
def get_default_state_factory(self): return get_tmix_default_state
|
133 |
+
|
134 |
+
def __init__(self, config:Qwen2Config, layer_id):
|
135 |
+
super().__init__()
|
136 |
+
self.config = config
|
137 |
+
self.layer_id = layer_id
|
138 |
+
self.ctx_len = config.max_position_embeddings
|
139 |
+
|
140 |
+
self.head_dim = config.hidden_size // config.num_attention_heads
|
141 |
+
|
142 |
+
self.hidden_size = config.hidden_size
|
143 |
+
self.num_heads = config.num_attention_heads
|
144 |
+
self.num_key_value_heads = config.num_key_value_heads if config.num_key_value_heads > 0 else self.num_heads
|
145 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
146 |
+
# self.max_position_embeddings = config.max_position_embeddings
|
147 |
+
# self.rope_theta = config.rope_theta
|
148 |
+
# self.is_causal = True
|
149 |
+
# self.attention_dropout = config.attention_dropout
|
150 |
+
|
151 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
152 |
+
raise ValueError(
|
153 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
154 |
+
f" and `num_heads`: {self.num_heads})."
|
155 |
+
)
|
156 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
157 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
158 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
159 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
160 |
+
|
161 |
+
# self.rotary_emb = Qwen2RotaryEmbedding(
|
162 |
+
# self.head_dim,
|
163 |
+
# max_position_embeddings=config.rope.max_seqlen,
|
164 |
+
# base=config.rope.base,
|
165 |
+
# )
|
166 |
+
|
167 |
+
def forward(self, x, last_model_state:ModelState, shared:Shared, output_attentions:bool=False):
|
168 |
+
last_state = last_model_state.block_states[self.layer_id].time_mix_state
|
169 |
+
B, L, D = x.size()
|
170 |
+
QH = self.num_heads
|
171 |
+
KVH = self.num_key_value_heads
|
172 |
+
|
173 |
+
q = self.q_proj(x)
|
174 |
+
k = self.k_proj(x)
|
175 |
+
v = self.v_proj(x)
|
176 |
+
|
177 |
+
wkv_state = last_state.wkv_state
|
178 |
+
|
179 |
+
# handle recurrent inference via maintaining a kv cache
|
180 |
+
# if not self.training:
|
181 |
+
# new_kv_cache = torch.stack([k, v], dim=0)
|
182 |
+
# wkv_state = torch.cat([wkv_state, new_kv_cache], dim=-2)
|
183 |
+
# k, v = wkv_state.unbind(0)
|
184 |
+
# k, v = k.contiguous(), v.contiguous()
|
185 |
+
|
186 |
+
is_causal = q.size(1)==k.size(1)
|
187 |
+
|
188 |
+
q = q.view(B,L,QH,-1).transpose(1,2)
|
189 |
+
k = k.view(B,L,KVH,-1).transpose(1,2)
|
190 |
+
v = v.view(B,L,KVH,-1).transpose(1,2)
|
191 |
+
|
192 |
+
#q, k = apply_rotary_embedding(q, k, shared.angles)
|
193 |
+
#kv_seq_len, position_ids = L, torch.arange(L, dtype=torch.int, device=v.device).view(1, L).expand(B, L)
|
194 |
+
#cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)
|
195 |
+
cos, sin = shared.angles.unbind(0)
|
196 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
197 |
+
q = q.to(v.dtype)
|
198 |
+
k = k.to(v.dtype)
|
199 |
+
|
200 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
201 |
+
k = repeat_kv(k, self.num_key_value_groups)
|
202 |
+
v = repeat_kv(v, self.num_key_value_groups)
|
203 |
+
|
204 |
+
if output_attentions:
|
205 |
+
attn_weights = (q * (self.head_dim ** -0.5)) @ k.mT
|
206 |
+
|
207 |
+
#y = nn.functional.softmax(attn_weights + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype) @ v
|
208 |
+
#attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
209 |
+
#y = torch.matmul(attn_weights, v)
|
210 |
+
|
211 |
+
# NOTE - we are outputting the non-softmaxed attention weights, just with exp() maxed to 1.0 since we're comparing against pre-normalized output of linear attention
|
212 |
+
# upcast attention to fp32
|
213 |
+
causal_mask = torch.full([L, L], fill_value=-torch.inf, device=attn_weights.device, dtype=attn_weights.dtype).triu(1)
|
214 |
+
|
215 |
+
attn_weights = nn.functional.softmax(attn_weights + causal_mask, dim=-1, dtype=torch.float32).to(q.dtype)
|
216 |
+
|
217 |
+
#attn_weights = attn_weights.tril()
|
218 |
+
#attn_weights = (attn_weights - attn_weights.max() + causal_mask).exp()
|
219 |
+
#attn_weights = (attn_weights - torch.max(attn_weights, dim=-1, keepdim=True).values + causal_mask).exp()
|
220 |
+
else:
|
221 |
+
attn_weights = torch.empty(0, device=x.device)
|
222 |
+
|
223 |
+
y = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=is_causal)
|
224 |
+
y = y.transpose(1,2).reshape(B,L,D)
|
225 |
+
y = self.o_proj(y)
|
226 |
+
return y, TimeMixState(wkv_state, last_state.shift_state), attn_weights
|
227 |
+
|
228 |
+
class TMix_qwen2rwkv(TMix_qwen2):
|
229 |
+
"""
|
230 |
+
Qwen2 RWKV-6cSimple attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
|
231 |
+
and adds RWKV specific weights for tokenshift, decay, time_first, and the final layernorm.
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, config:Qwen2Config, layer_id):
|
235 |
+
super().__init__(config, layer_id)
|
236 |
+
|
237 |
+
n_layer = config.num_hidden_layers
|
238 |
+
n_embd = self.hidden_size
|
239 |
+
dim_att = self.num_heads * self.head_dim
|
240 |
+
layer_id = self.layer_id
|
241 |
+
|
242 |
+
with torch.no_grad():
|
243 |
+
ratio_0_to_1 = layer_id / (n_layer - 1) # 0 to 1
|
244 |
+
ratio_1_to_almost0 = 1.0 - (layer_id / n_layer) # 1 to ~0
|
245 |
+
ddd = torch.ones(1, 1, n_embd)
|
246 |
+
for i in range(n_embd):
|
247 |
+
ddd[0, 0, i] = i / n_embd
|
248 |
+
|
249 |
+
# self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
250 |
+
# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
|
251 |
+
# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
252 |
+
# self.time_maa_v = nn.Parameter(1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1))
|
253 |
+
# self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
254 |
+
|
255 |
+
ddd = torch.zeros(1, 1, n_embd)
|
256 |
+
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
|
257 |
+
self.time_maa_r = nn.Parameter(torch.zeros_like(ddd))
|
258 |
+
self.time_maa_k = nn.Parameter(torch.zeros_like(ddd))
|
259 |
+
self.time_maa_v = nn.Parameter(torch.zeros_like(ddd))
|
260 |
+
self.time_maa_w = nn.Parameter(torch.zeros_like(ddd))
|
261 |
+
self.time_maa_g = nn.Parameter(torch.zeros_like(ddd))
|
262 |
+
|
263 |
+
D_MIX_LORA = 32 if n_embd < 4096 else 64
|
264 |
+
self.time_maa_w2 = nn.Parameter(torch.zeros(5, D_MIX_LORA, n_embd).uniform_(-0.01, 0.01))
|
265 |
+
self.time_maa_w1 = nn.Parameter(torch.zeros(n_embd, D_MIX_LORA*self.time_maa_w2.size(0)))
|
266 |
+
|
267 |
+
# # per-head RWKV-6
|
268 |
+
# H = self.num_heads
|
269 |
+
# # fancy time_decay
|
270 |
+
# decay_speed = torch.ones(H)
|
271 |
+
# for h in range(H):
|
272 |
+
# decay_speed[h] = -6 + 5 * (h / max(H - 1, 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
273 |
+
# self.time_decay = nn.Parameter(decay_speed)
|
274 |
+
# #self.time_decay = nn.Parameter(torch.empty(H)).uniform_(-8, -7)
|
275 |
+
# D_DECAY_LORA = 64 if n_embd < 4096 else 128
|
276 |
+
# self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
|
277 |
+
# self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, H).uniform_(-0.01, 0.01))
|
278 |
+
|
279 |
+
# RWKV-6
|
280 |
+
decay_speed = torch.ones(dim_att)
|
281 |
+
for n in range(dim_att):
|
282 |
+
decay_speed[n] = -6 + 5 * (n / (dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
|
283 |
+
self.time_decay = nn.Parameter(decay_speed.reshape(1,1,dim_att))
|
284 |
+
D_DECAY_LORA = 64 if n_embd < 4096 else 128
|
285 |
+
self.time_decay_w1 = nn.Parameter(torch.zeros(n_embd, D_DECAY_LORA))
|
286 |
+
self.time_decay_w2 = nn.Parameter(torch.zeros(D_DECAY_LORA, dim_att).uniform_(-0.01, 0.01))
|
287 |
+
# tmp = torch.zeros(dim_att)
|
288 |
+
# for n in range(dim_att):
|
289 |
+
# zigzag = ((n + 1) % 3 - 1) * 0.1
|
290 |
+
# tmp[n] = ratio_0_to_1 * (1 - (n / (dim_att - 1))) + zigzag
|
291 |
+
# self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
|
292 |
+
|
293 |
+
self.gate = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
294 |
+
# start gate out with no effect
|
295 |
+
nn.init.zeros_(self.gate.weight)
|
296 |
+
#nn.init.ones_(self.gate.bias)
|
297 |
+
|
298 |
+
#self.ln_x = nn.LayerNorm(dim_att)
|
299 |
+
|
300 |
+
def segsum(self, w_log): # B H L 1
|
301 |
+
w_log_cumsum = torch.cumsum(w_log, dim=-2) # (B, H, L, 1)
|
302 |
+
w_mask = torch.exp((w_log_cumsum - w_log_cumsum.mT).tril()).tril() # (B, H, L, L)
|
303 |
+
return w_mask
|
304 |
+
|
305 |
+
def forward(self, x, last_model_state:ModelState, shared:Shared, output_attentions:bool=False):
|
306 |
+
last_state = last_model_state.block_states[self.layer_id].time_mix_state
|
307 |
+
bsz, q_len, hidden_dim = x.size()
|
308 |
+
|
309 |
+
dxprev = torch.nn.functional.pad(x, (0, 0, 1, -1)) - x
|
310 |
+
|
311 |
+
xxx = x + dxprev * self.time_maa_x
|
312 |
+
xxx = torch.tanh(xxx @ self.time_maa_w1).view(bsz*q_len, self.time_maa_w2.size(0), -1).transpose(0, 1)
|
313 |
+
xxx = torch.bmm(xxx, self.time_maa_w2).view(self.time_maa_w2.size(0), bsz, q_len, hidden_dim)
|
314 |
+
|
315 |
+
mr, mk, mv, mw, mg = xxx.unbind(dim=0)
|
316 |
+
xr = x + dxprev * (self.time_maa_r + mr)
|
317 |
+
xk = x + dxprev * (self.time_maa_k + mk)
|
318 |
+
xv = x + dxprev * (self.time_maa_v + mv)
|
319 |
+
xw = x + dxprev * (self.time_maa_w + mw)
|
320 |
+
xg = x + dxprev * (self.time_maa_g + mg)
|
321 |
+
|
322 |
+
query_states = self.q_proj(xr)
|
323 |
+
key_states = self.k_proj(xk)
|
324 |
+
value_states = self.v_proj(xv)
|
325 |
+
decay_states = (self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2).to(query_states.dtype)
|
326 |
+
gate_states = torch.sigmoid(self.gate(xg))
|
327 |
+
|
328 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
329 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
330 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
331 |
+
decay_states = decay_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
332 |
+
|
333 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
334 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
335 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
336 |
+
#dropout_rate = 0.0 if not self.training else self.attention_dropout
|
337 |
+
|
338 |
+
decay_states_log = -decay_states.float().exp()
|
339 |
+
decay_states_log = decay_states_log.clamp(-5) # FIXME - is this necessary?
|
340 |
+
key_states = (key_states * (1 - decay_states_log.exp())).to(key_states.dtype)
|
341 |
+
|
342 |
+
query_states = query_states.to(value_states.dtype)
|
343 |
+
key_states = key_states.to(value_states.dtype)
|
344 |
+
|
345 |
+
# # In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
346 |
+
# # therefore the input hidden states gets silently casted in float32. Hence, we need
|
347 |
+
# # cast them back in float16 just to be sure everything works as expected.
|
348 |
+
# input_dtype = query_states.dtype
|
349 |
+
# if input_dtype == torch.float32:
|
350 |
+
# if torch.is_autocast_enabled():
|
351 |
+
# target_dtype = torch.get_autocast_gpu_dtype()
|
352 |
+
# # Handle the case where the model is quantized
|
353 |
+
# elif hasattr(self.config, "_pre_quantization_dtype"):
|
354 |
+
# target_dtype = self.config._pre_quantization_dtype
|
355 |
+
# else:
|
356 |
+
# target_dtype = self.q_proj.weight.dtype
|
357 |
+
|
358 |
+
# logger.warning_once(
|
359 |
+
# f"The input hidden states seems to be silently casted in float32, this might be related to"
|
360 |
+
# f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
361 |
+
# f" {target_dtype}."
|
362 |
+
# )
|
363 |
+
|
364 |
+
# query_states = query_states.to(target_dtype)
|
365 |
+
# key_states = key_states.to(target_dtype)
|
366 |
+
# value_states = value_states.to(target_dtype)
|
367 |
+
|
368 |
+
# decay_states_log.view is to match fla_chunk_simple_gla's requirements
|
369 |
+
#print("layer", self.layer_id, "pre ", bool(query_states.isnan().any()), bool(key_states.isnan().any()), bool(value_states.isnan().any()), bool(decay_states_log.isnan().any()))
|
370 |
+
#o = chunk_simple_gla(q.contiguous(), k.contiguous(), v.contiguous(), g.contiguous(), scale)
|
371 |
+
|
372 |
+
#print("layer", self.layer_id, "post", bool(query_states.isnan().any()), bool(key_states.isnan().any()), bool(value_states.isnan().any()), bool(decay_states_log.isnan().any()))
|
373 |
+
|
374 |
+
if not output_attentions:
|
375 |
+
attn_weights = torch.empty(0, device=x.device)
|
376 |
+
|
377 |
+
#attn_output = fla_chunk_simple_gla(query_states, key_states, value_states, decay_states_log.view(bsz, self.num_heads, q_len))[0]
|
378 |
+
#attn_output = chunk_gla(query_states, key_states, value_states, decay_states_log)[0]
|
379 |
+
attn_output = fused_recurrent_gla(query_states, key_states, value_states, decay_states_log)[0]
|
380 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
381 |
+
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
382 |
+
#attn_output = self.ln_x(attn_output)
|
383 |
+
attn_output = self.o_proj(attn_output * gate_states)
|
384 |
+
else:
|
385 |
+
attn_weights = (query_states * (key_states.size(-1) ** -0.5)) @ key_states.mT
|
386 |
+
|
387 |
+
decay_states_log = decay_states_log.mean(-1, keepdim=True)
|
388 |
+
attn_weights = attn_weights.float() * self.segsum(decay_states_log.float()) # NOTE - without the explicit cast to float ddp mismatched deepspeed here
|
389 |
+
|
390 |
+
attn_weights = attn_weights.to(query_states.dtype)
|
391 |
+
attn_output = torch.empty(0, device=x.device)
|
392 |
+
|
393 |
+
return attn_output, TimeMixState(last_state.wkv_state, last_state.shift_state), attn_weights #, past_key_value
|
394 |
+
|
395 |
+
def get_cmix_default_state(x:Tensor, config:Qwen2Config, requires_grad:bool):
|
396 |
+
B, T, C = x.size()
|
397 |
+
return ChannelMixState(
|
398 |
+
torch.zeros([B, C], dtype=x.dtype, device=x.device, requires_grad=requires_grad)
|
399 |
+
)
|
400 |
+
|
401 |
+
class CMix_qwen2(nn.Module):
|
402 |
+
def get_default_state_factory(self): return get_cmix_default_state
|
403 |
+
|
404 |
+
def __init__(self, config:Qwen2Config, layer_id):
|
405 |
+
super().__init__()
|
406 |
+
self.config = config
|
407 |
+
self.layer_id = layer_id
|
408 |
+
|
409 |
+
self.hidden_size = config.hidden_size
|
410 |
+
self.intermediate_size = config.intermediate_size
|
411 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
412 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
413 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
414 |
+
self.act_fn = torch.nn.SiLU() #ACT2FN[config.hidden_act]
|
415 |
+
|
416 |
+
def forward(self, x, last_model_state:ModelState):
|
417 |
+
last_state = last_model_state.block_states[self.layer_id].channel_mix_state
|
418 |
+
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)), last_state
|
419 |
+
|
420 |
+
class Qwen2DecoderLayer(nn.Module):
|
421 |
+
def __init__(self, config:Qwen2Config, layer_id:int):
|
422 |
+
super().__init__()
|
423 |
+
|
424 |
+
self.config = config
|
425 |
+
|
426 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
427 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
428 |
+
|
429 |
+
cmix = CMix_qwen2(config, layer_id)
|
430 |
+
|
431 |
+
#if args.attention_type == 'rwkv':
|
432 |
+
self.self_attn = TMix_qwen2rwkv(config, layer_id)
|
433 |
+
#else:
|
434 |
+
# self.self_attn = TMix_qwen2(config, layer_id)
|
435 |
+
self.default_time_mix_state_factory = self.self_attn.get_default_state_factory() if hasattr(self.self_attn, 'get_default_state_factory') else lambda x, c, r: TimeMixState()
|
436 |
+
|
437 |
+
self.teacher_attn = None
|
438 |
+
#if config.train is not None:
|
439 |
+
# if config.train.attention_distillation_stage in (1, 2):
|
440 |
+
# self.teacher_attn = TMix_qwen2(config, layer_id)
|
441 |
+
|
442 |
+
self.default_channel_mix_state_factory = cmix.get_default_state_factory() if hasattr(cmix, 'get_default_state_factory') else lambda x, c, r: ChannelMixState()
|
443 |
+
self.mlp = cmix
|
444 |
+
|
445 |
+
def forward(self, x:Tensor, last_model_state:ModelState, shared:Shared, output_attentions:bool, output_post_attention_hidden_states:bool):
|
446 |
+
s = last_model_state
|
447 |
+
if self.teacher_attn is not None:
|
448 |
+
dx, last_timemix_state, attentions = self.teacher_attn(self.input_layernorm(x), s, shared, output_attentions)
|
449 |
+
student_dx, student_last_timemix_state, student_attentions = self.self_attn(self.input_layernorm(x), s, shared, output_attentions)
|
450 |
+
else:
|
451 |
+
dx, last_timemix_state, attentions = self.self_attn(self.input_layernorm(x), s, shared, output_attentions)
|
452 |
+
student_dx, student_last_timemix_state, student_attentions = None, None, None
|
453 |
+
if output_post_attention_hidden_states:
|
454 |
+
post_attention_hidden_states = dx
|
455 |
+
student_post_attention_hidden_states = student_dx
|
456 |
+
else:
|
457 |
+
post_attention_hidden_states = torch.empty(0, device=x.device)
|
458 |
+
student_post_attention_hidden_states = torch.empty(0, device=x.device)
|
459 |
+
|
460 |
+
x = x + dx
|
461 |
+
dx, last_chanmix_state = self.mlp(self.post_attention_layernorm(x), s)
|
462 |
+
x = x + dx
|
463 |
+
return x, s, attentions, post_attention_hidden_states, student_attentions, student_post_attention_hidden_states
|
464 |
+
|
465 |
+
def ckpt(block:Qwen2DecoderLayer, *block_args):
|
466 |
+
# if block.training and block.config.train.grad_cp == 1 and 'fsdp' not in block.config.train.strategy: # FSDP has its own checkpointing wrapper
|
467 |
+
#if "deepspeed" in block.config.train.strategy:
|
468 |
+
# results = deepspeed.checkpointing.checkpoint(block, *block_args)
|
469 |
+
#else:
|
470 |
+
# NOTE - both deepspeed.checkpointing.checkpoint and use_reentrant=True failed miserably (bad loss) when used in conjunction with requires_grad=False params and grad_cp with deepspeed
|
471 |
+
results = torch.utils.checkpoint.checkpoint(block, *block_args, use_reentrant=False)
|
472 |
+
# else:
|
473 |
+
# results = block(*block_args)
|
474 |
+
return results
|
475 |
+
|
476 |
+
class Qwen2Decoder(nn.Module):
|
477 |
+
def __init__(self, config:Qwen2Config):
|
478 |
+
super().__init__()
|
479 |
+
|
480 |
+
self.config = config
|
481 |
+
|
482 |
+
self.shared = Shared()
|
483 |
+
|
484 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) #, config.vocab_padding_idx)
|
485 |
+
self.layers = nn.ModuleList(
|
486 |
+
[Qwen2DecoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)]
|
487 |
+
)
|
488 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
489 |
+
|
490 |
+
def forward_preamble(self, x, last_model_state:ModelState|None = None, ):
|
491 |
+
config = self.config
|
492 |
+
|
493 |
+
B, T, C = x.size()
|
494 |
+
|
495 |
+
shared = self.shared
|
496 |
+
#if config.rope is not None and shared.angles.size(0) == 0:
|
497 |
+
# shared.angles = generate_rotary_embedding(config.max_position_embeddings, config.head_size, config.rope_theta).to(self.norm.weight)
|
498 |
+
|
499 |
+
assert (shared.angles.size(0) == 0 or T <= shared.angles.size(0)) or (shared.bias_mask.size(0) == 0 or T <= shared.bias_mask.size(0))
|
500 |
+
|
501 |
+
# might need to be true in the future for BPTT support
|
502 |
+
requires_grad = self.training
|
503 |
+
if last_model_state is None:
|
504 |
+
last_model_state = ModelState()
|
505 |
+
for layer_id in range(config.num_hidden_layers):
|
506 |
+
layer = self.layers[layer_id]
|
507 |
+
last_model_state.block_states.append(BlockState(
|
508 |
+
layer.default_time_mix_state_factory(x, config, requires_grad),
|
509 |
+
layer.default_channel_mix_state_factory(x, config, requires_grad),
|
510 |
+
))
|
511 |
+
|
512 |
+
return last_model_state
|
513 |
+
|
514 |
+
def forward(self, token_ids:Tensor|list, last_model_state:ModelState|None = None, output_hidden_states:bool=False, output_attentions:bool=False, output_post_attention_hidden_states:bool=False):
|
515 |
+
config = self.config
|
516 |
+
if isinstance(token_ids, Tensor):
|
517 |
+
B, T = token_ids.size()
|
518 |
+
else:
|
519 |
+
B = 1
|
520 |
+
T = len(token_ids)
|
521 |
+
token_ids = torch.tensor(token_ids, device=self.embed_tokens.weight.device, dtype=torch.long, requires_grad=False)[None, :]
|
522 |
+
|
523 |
+
x = self.embed_tokens(token_ids)
|
524 |
+
|
525 |
+
last_model_state = self.forward_preamble(x, last_model_state)
|
526 |
+
|
527 |
+
hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs = (), (), ()
|
528 |
+
student_hidden_states_outputs, student_attentions_outputs, student_post_attention_hidden_states_outputs = (), (), ()
|
529 |
+
if output_hidden_states:
|
530 |
+
hidden_states_outputs += (x,)
|
531 |
+
student_hidden_states_outputs += (x,)
|
532 |
+
for decoder_layer in self.layers:
|
533 |
+
x, s, attentions, post_attention_hidden_states, student_attentions, student_post_attention_hidden_states = ckpt(decoder_layer, x, last_model_state, self.shared, output_attentions, output_post_attention_hidden_states)
|
534 |
+
hidden_states_outputs += (x,)
|
535 |
+
student_hidden_states_outputs += (x,)
|
536 |
+
if output_attentions:
|
537 |
+
attentions_outputs += (attentions,)
|
538 |
+
student_attentions_outputs += (student_attentions,)
|
539 |
+
if output_post_attention_hidden_states:
|
540 |
+
post_attention_hidden_states_outputs += (post_attention_hidden_states,)
|
541 |
+
student_post_attention_hidden_states_outputs += (student_post_attention_hidden_states,)
|
542 |
+
|
543 |
+
x = self.norm(x)
|
544 |
+
return LLMOutput(x, last_model_state, hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs, student_attentions_outputs, student_post_attention_hidden_states_outputs)
|
545 |
+
#return x, last_model_state, hidden_states_outputs, attentions_outputs, post_attention_hidden_states_outputs # FIXME - not updating state at all
|
546 |
+
|
547 |
+
class Model_qwen2(nn.Module): # Qwen2CausalLM
|
548 |
+
def __init__(self, config:Qwen2Config):
|
549 |
+
super().__init__()
|
550 |
+
|
551 |
+
self.config = config
|
552 |
+
|
553 |
+
self.model = None
|
554 |
+
|
555 |
+
self.configure_model()
|
556 |
+
|
557 |
+
def configure_model(self):
|
558 |
+
if self.model is not None:
|
559 |
+
return
|
560 |
+
|
561 |
+
self.model = Qwen2Decoder(self.config)
|
562 |
+
|
563 |
+
self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
|
564 |
+
|
565 |
+
def forward(self, input_ids:Tensor|list, last_model_state:ModelState|None = None, output_hidden_states:bool=False, output_attentions:bool=False, output_post_attention_hidden_states:bool=False):
|
566 |
+
#print("teacher q min, max", float(self.model.layers[0].self_attn.q_proj.weight.min()), float(self.model.layers[0].self_attn.q_proj.weight.max()))
|
567 |
+
results = self.model(input_ids, last_model_state=last_model_state, output_hidden_states=output_hidden_states, output_attentions=output_attentions, output_post_attention_hidden_states=output_post_attention_hidden_states)
|
568 |
+
results.logits = self.lm_head(results.logits)
|
569 |
+
return results
|
570 |
+
|
571 |
+
def get_optim_groups(self):
|
572 |
+
# separates groups for weight decay and non-weight decay
|
573 |
+
|
574 |
+
config = self.config
|
575 |
+
|
576 |
+
weight_decay = 0.0 # train_config.weight_decay
|
577 |
+
|
578 |
+
# if train_config.attention_distillation_stage in (1, 2):
|
579 |
+
# self.requires_grad_(False)
|
580 |
+
# for decoder_layer in self.model.layers:
|
581 |
+
# if train_config.attention_distillation_stage == 1:
|
582 |
+
# decoder_layer.self_attn.time_maa_x.requires_grad_(True)
|
583 |
+
# decoder_layer.self_attn.time_maa_r.requires_grad_(True)
|
584 |
+
# decoder_layer.self_attn.time_maa_k.requires_grad_(True)
|
585 |
+
# decoder_layer.self_attn.time_maa_w.requires_grad_(True)
|
586 |
+
# decoder_layer.self_attn.time_maa_w1.requires_grad_(True)
|
587 |
+
# decoder_layer.self_attn.time_maa_w2.requires_grad_(True)
|
588 |
+
# decoder_layer.self_attn.time_decay.requires_grad_(True)
|
589 |
+
# decoder_layer.self_attn.time_decay_w1.requires_grad_(True)
|
590 |
+
# decoder_layer.self_attn.time_decay_w2.requires_grad_(True)
|
591 |
+
# # FIXME - wow we removed q, k here by accident and it.. helped??!?!
|
592 |
+
# # decoder_layer.self_attn.q_proj.requires_grad_(True)
|
593 |
+
# # decoder_layer.self_attn.k_proj.requires_grad_(True)
|
594 |
+
# elif train_config.attention_distillation_stage == 2:
|
595 |
+
# decoder_layer.self_attn.requires_grad_(True)
|
596 |
+
|
597 |
+
# FIXME - remove these for full training
|
598 |
+
# for decoder_layer in self.model.layers:
|
599 |
+
# decoder_layer.post_attention_layernorm.requires_grad_(False)
|
600 |
+
# decoder_layer.mlp.requires_grad_(False)
|
601 |
+
# self.model.embed_tokens.requires_grad_(False)
|
602 |
+
# self.model.norm.requires_grad_(False)
|
603 |
+
# self.lm_head.requires_grad_(False)
|
604 |
+
|
605 |
+
# JIT at last minute
|
606 |
+
for decoder_layer in self.model.layers:
|
607 |
+
decoder_layer.self_attn = TJIT(decoder_layer.self_attn)
|
608 |
+
decoder_layer.mlp = TJIT(decoder_layer.mlp)
|
609 |
+
|
610 |
+
lr_decay = set()
|
611 |
+
lr_1x = set()
|
612 |
+
lr_fp32 = set()
|
613 |
+
for n, p in self.named_parameters():
|
614 |
+
if not p.requires_grad:
|
615 |
+
continue
|
616 |
+
if 'lm_head.weight' in n or 'embed_tokens.weight' in n:
|
617 |
+
lr_fp32.add(n)
|
618 |
+
elif (len(p.squeeze().shape) >= 2) and (weight_decay > 0):
|
619 |
+
lr_decay.add(n)
|
620 |
+
else:
|
621 |
+
lr_1x.add(n)
|
622 |
+
|
623 |
+
param_dict = {n: p for n, p in self.named_parameters()}
|
624 |
+
param_check = list(lr_decay) + list(lr_1x) + list(lr_fp32)
|
625 |
+
#if not train_config.load_partial and (train_config.teacher is None or train_config.teacher.attention_distillation_stage ==3):
|
626 |
+
# assert sorted(param_dict) == sorted(param_check)
|
627 |
+
|
628 |
+
lr_decay = sorted(list(lr_decay))
|
629 |
+
lr_1x = sorted(list(lr_1x))
|
630 |
+
lr_fp32 = sorted(list(lr_fp32))
|
631 |
+
|
632 |
+
print('decay', lr_decay, '\n')
|
633 |
+
print('1x', lr_1x, '\n')
|
634 |
+
print('fp32', lr_fp32, '\n')
|
635 |
+
|
636 |
+
|
637 |
+
optim_groups = [
|
638 |
+
{"params": [param_dict[n] for n in lr_1x], "weight_decay": 0.0, "use_fp16": True, "my_lr_scale": 1.0, 'name':'lr_1x'},
|
639 |
+
]
|
640 |
+
if len(lr_fp32) > 0:
|
641 |
+
optim_groups += [{"params": [param_dict[n] for n in lr_fp32], "weight_decay": weight_decay, "my_lr_scale": 1.0, 'name':'lr_tiny'}]
|
642 |
+
if len(lr_decay) > 0:
|
643 |
+
optim_groups += [{"params": [param_dict[n] for n in lr_decay], "weight_decay": weight_decay, "use_fp16": True, "my_lr_scale": 1.0, 'name':'lr_decay'}]
|
644 |
+
|
645 |
+
return optim_groups
|
646 |
+
|
647 |
+
# def _init_weights(self, module):
|
648 |
+
# std = 0.02 #self.config.initializer_range
|
649 |
+
# if isinstance(module, nn.Linear):
|
650 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
651 |
+
# if module.bias is not None:
|
652 |
+
# module.bias.data.zero_()
|
653 |
+
# elif isinstance(module, nn.Embedding):
|
654 |
+
# module.weight.data.normal_(mean=0.0, std=std)
|
655 |
+
# if module.padding_idx is not None:
|
656 |
+
# module.weight.data[module.padding_idx].zero_()
|
657 |
+
|
658 |
+
def init_all_weights(self):
|
659 |
+
# FIXME - we really need to init the gate here to identity (zero weights, ones bias) but instead we just won't init weights since they're all grabbed or set anyway
|
660 |
+
pass
|
661 |
+
# self.apply(self._init_weights)
|
662 |
+
# for n, p in self.named_parameters():
|
663 |
+
# requires_grad_temp = p.requires_grad
|
664 |
+
# p.requires_grad_(False)
|
665 |
+
# if n.endswith('.ln_x.weight'):
|
666 |
+
# layer_scale = (1+int(n.split('.')[2])) / self.config.model.n_layer
|
667 |
+
# print('.ln_x.weight layer', int(n.split('.')[2]), "scale", (layer_scale ** 0.7))
|
668 |
+
# p *= 0.0
|
669 |
+
# p += (layer_scale ** 0.7)
|
670 |
+
# p.requires_grad = requires_grad_temp
|
run_lm_eval.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, os
|
2 |
+
import transformers # just for a bugfix for 0.4.2 of lm_eval
|
3 |
+
|
4 |
+
import torch
|
5 |
+
torch.backends.cudnn.benchmark = True
|
6 |
+
torch.backends.cudnn.allow_tf32 = True
|
7 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
8 |
+
|
9 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
|
10 |
+
from configs import parse_cmdline_configs
|
11 |
+
from pydoc import locate
|
12 |
+
|
13 |
+
from lm_eval import evaluator
|
14 |
+
from lm_eval.models.huggingface import HFLM
|
15 |
+
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Callable
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class CLI_Config:
|
21 |
+
tokenizer_path: str
|
22 |
+
model_path: str
|
23 |
+
attn_path: str = 'rwkv6attn.RWKV6Attention'
|
24 |
+
tasks: str = 'lambada_openai' # arc_challenge, arc_easy, headqa, openbookqa, hellaswag, winogrande, piqa, record, copa, storycloze_2016
|
25 |
+
bsz: int|str = 'auto'
|
26 |
+
precision: int | str = 'bf16'
|
27 |
+
num_fewshot: int = 0
|
28 |
+
attn_classes_path: str = 'transformers.models.qwen2.modeling_qwen2.QWEN2_ATTENTION_CLASSES' # 'transformers.models.llama.modeling_llama.LLAMA_ATTENTION_CLASSES'
|
29 |
+
seed: int | None = None
|
30 |
+
train:Any = None
|
31 |
+
|
32 |
+
config, errors = parse_cmdline_configs(sys.argv[1:], CLI_Config)
|
33 |
+
if errors != '':
|
34 |
+
print(errors)
|
35 |
+
exit()
|
36 |
+
|
37 |
+
match config.precision:
|
38 |
+
case 32:
|
39 |
+
dtype = torch.float32
|
40 |
+
case '32':
|
41 |
+
dtype = torch.float32
|
42 |
+
case 16:
|
43 |
+
dtype = torch.float16
|
44 |
+
case '16':
|
45 |
+
dtype = torch.float16
|
46 |
+
case 'bf16':
|
47 |
+
dtype = torch.bfloat16
|
48 |
+
case _:
|
49 |
+
print("Bad precision type specified")
|
50 |
+
exit()
|
51 |
+
|
52 |
+
# avoid 1000 huggingface warnings "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...""
|
53 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
54 |
+
|
55 |
+
print(f'Loading model - {config.model_path}')
|
56 |
+
|
57 |
+
model_config = AutoConfig.from_pretrained(config.model_path)
|
58 |
+
|
59 |
+
if config.model_path.startswith('.') or config.model_path.startswith('/'):
|
60 |
+
# replace attention classes
|
61 |
+
ReplacementSelfAttentionType = locate(config.attn_path)
|
62 |
+
assert isinstance(ReplacementSelfAttentionType, Callable)
|
63 |
+
attn_classes_dict = locate(config.attn_classes_path)
|
64 |
+
assert isinstance(attn_classes_dict, dict), 'could not find attention classes dict at path provided'
|
65 |
+
for key in list(attn_classes_dict.keys()):
|
66 |
+
attn_classes_dict[key] = ReplacementSelfAttentionType
|
67 |
+
|
68 |
+
model = AutoModelForCausalLM.from_pretrained(config.model_path, config=model_config, torch_dtype=dtype, device_map='cuda')
|
69 |
+
|
70 |
+
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
|
71 |
+
|
72 |
+
#device = 'cuda'
|
73 |
+
#model = model.to(device=device, dtype=dtype)
|
74 |
+
model.eval()
|
75 |
+
|
76 |
+
eval_tasks = config.tasks.split(',')
|
77 |
+
|
78 |
+
if config.seed is None:
|
79 |
+
config.seed = 1234
|
80 |
+
|
81 |
+
adapter = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=config.bsz)
|
82 |
+
with torch.no_grad():
|
83 |
+
with torch.amp.autocast(device_type='cuda', dtype=dtype):
|
84 |
+
results = evaluator.simple_evaluate(
|
85 |
+
model=adapter,
|
86 |
+
tasks=eval_tasks,
|
87 |
+
#provide_description=False,
|
88 |
+
num_fewshot=config.num_fewshot,
|
89 |
+
limit=None,
|
90 |
+
bootstrap_iters=10000,
|
91 |
+
numpy_random_seed = config.seed,
|
92 |
+
torch_random_seed = config.seed,
|
93 |
+
fewshot_random_seed = config.seed,
|
94 |
+
)
|
95 |
+
|
96 |
+
print(results['results'])
|
tokenization_rwkv6qwen2.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer
|
2 |
+
|
3 |
+
class RWKV6Qwen2Tokenizer(Qwen2Tokenizer):
|
4 |
+
pass
|
tokenization_rwkv6qwen2_fast.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
2 |
+
|
3 |
+
class RWKV6Qwen2TokenizerFast(Qwen2TokenizerFast):
|
4 |
+
pass
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_prefix_space": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"151643": {
|
6 |
+
"content": "<|endoftext|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"151644": {
|
14 |
+
"content": "<|im_start|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"151645": {
|
22 |
+
"content": "<|im_end|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"151646": {
|
30 |
+
"content": "<|object_ref_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"151647": {
|
38 |
+
"content": "<|object_ref_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"151648": {
|
46 |
+
"content": "<|box_start|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"151649": {
|
54 |
+
"content": "<|box_end|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"151650": {
|
62 |
+
"content": "<|quad_start|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"151651": {
|
70 |
+
"content": "<|quad_end|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"151652": {
|
78 |
+
"content": "<|vision_start|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"151653": {
|
86 |
+
"content": "<|vision_end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"151654": {
|
94 |
+
"content": "<|vision_pad|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"151655": {
|
102 |
+
"content": "<|image_pad|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"151656": {
|
110 |
+
"content": "<|video_pad|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"151657": {
|
118 |
+
"content": "<tool_call>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": false
|
124 |
+
},
|
125 |
+
"151658": {
|
126 |
+
"content": "</tool_call>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false,
|
131 |
+
"special": false
|
132 |
+
},
|
133 |
+
"151659": {
|
134 |
+
"content": "<|fim_prefix|>",
|
135 |
+
"lstrip": false,
|
136 |
+
"normalized": false,
|
137 |
+
"rstrip": false,
|
138 |
+
"single_word": false,
|
139 |
+
"special": false
|
140 |
+
},
|
141 |
+
"151660": {
|
142 |
+
"content": "<|fim_middle|>",
|
143 |
+
"lstrip": false,
|
144 |
+
"normalized": false,
|
145 |
+
"rstrip": false,
|
146 |
+
"single_word": false,
|
147 |
+
"special": false
|
148 |
+
},
|
149 |
+
"151661": {
|
150 |
+
"content": "<|fim_suffix|>",
|
151 |
+
"lstrip": false,
|
152 |
+
"normalized": false,
|
153 |
+
"rstrip": false,
|
154 |
+
"single_word": false,
|
155 |
+
"special": false
|
156 |
+
},
|
157 |
+
"151662": {
|
158 |
+
"content": "<|fim_pad|>",
|
159 |
+
"lstrip": false,
|
160 |
+
"normalized": false,
|
161 |
+
"rstrip": false,
|
162 |
+
"single_word": false,
|
163 |
+
"special": false
|
164 |
+
},
|
165 |
+
"151663": {
|
166 |
+
"content": "<|repo_name|>",
|
167 |
+
"lstrip": false,
|
168 |
+
"normalized": false,
|
169 |
+
"rstrip": false,
|
170 |
+
"single_word": false,
|
171 |
+
"special": false
|
172 |
+
},
|
173 |
+
"151664": {
|
174 |
+
"content": "<|file_sep|>",
|
175 |
+
"lstrip": false,
|
176 |
+
"normalized": false,
|
177 |
+
"rstrip": false,
|
178 |
+
"single_word": false,
|
179 |
+
"special": false
|
180 |
+
}
|
181 |
+
},
|
182 |
+
"additional_special_tokens": [
|
183 |
+
"<|im_start|>",
|
184 |
+
"<|im_end|>",
|
185 |
+
"<|object_ref_start|>",
|
186 |
+
"<|object_ref_end|>",
|
187 |
+
"<|box_start|>",
|
188 |
+
"<|box_end|>",
|
189 |
+
"<|quad_start|>",
|
190 |
+
"<|quad_end|>",
|
191 |
+
"<|vision_start|>",
|
192 |
+
"<|vision_end|>",
|
193 |
+
"<|vision_pad|>",
|
194 |
+
"<|image_pad|>",
|
195 |
+
"<|video_pad|>"
|
196 |
+
],
|
197 |
+
"bos_token": null,
|
198 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
199 |
+
"clean_up_tokenization_spaces": false,
|
200 |
+
"eos_token": "<|im_end|>",
|
201 |
+
"errors": "replace",
|
202 |
+
"model_max_length": 131072,
|
203 |
+
"pad_token": "<|endoftext|>",
|
204 |
+
"split_special_tokens": false,
|
205 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
206 |
+
"unk_token": null
|
207 |
+
}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|