smcleish committed · verified
Commit aaa9b6f · 1 Parent(s): 146a5cc

Upload RavenForCausalLM

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
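
The "How to Get Started" section above is still the unfilled template, so the snippet below is only a sketch of the standard loading path implied by the `auto_map` entries in `config.json` (next file): `trust_remote_code=True` is needed because the architecture is defined in `raven_modeling_minimal.py`, the repository id is a placeholder, and a compatible tokenizer is assumed to be published alongside these weights.

```python
# Sketch only: "<namespace>/<repo-id>" is a placeholder and the tokenizer is assumed to exist.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "<namespace>/<repo-id>"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,        # architecture lives in raven_modeling_minimal.py
    torch_dtype=torch.bfloat16,
).eval()

inputs = tokenizer("The capital of Westphalia is", return_tensors="pt")
# num_steps sets how many times the recurrent core block is unrolled for this forward pass
outputs = model(inputs["input_ids"], num_steps=32)
print(outputs.logits.shape)
```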
config.json ADDED
@@ -0,0 +1,50 @@
1
+ {
2
+ "activation_checkpoint_impl": "per-iteration",
3
+ "architecture_class_name": "RecurrentGPT",
4
+ "architectures": [
5
+ "RavenForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "raven_config_minimal.RavenConfig",
9
+ "AutoModelForCausalLM": "raven_modeling_minimal.RavenForCausalLM"
10
+ },
11
+ "bias": false,
12
+ "block_class_name": "SandwichBlock",
13
+ "block_size": 4096,
14
+ "effective_expected_depth": 132,
15
+ "head_dim": 96,
16
+ "init_orthogonal": false,
17
+ "init_strategy": "takase",
18
+ "init_values": {
19
+ "embed_scale": 72.6636084983398,
20
+ "embedding": 0.008703882797784892,
21
+ "out_proj": 0.0005356869554443541,
22
+ "std": 0.008703882797784892
23
+ },
24
+ "injection_type": "linear",
25
+ "intermediate_size": 17920,
26
+ "mean_backprop_depth": 8,
27
+ "mean_recurrence": 32,
28
+ "mlp_class_name": "GatedMLP",
29
+ "model_type": "huginn_raven",
30
+ "n_embd": 5280,
31
+ "n_heads": 55,
32
+ "n_layers": 8,
33
+ "n_layers_in_coda": 2,
34
+ "n_layers_in_prelude": 2,
35
+ "n_layers_in_recurrent_block": 4,
36
+ "nonlin_name": "SiLU",
37
+ "norm_class_name": "RMSNorm_llama",
38
+ "norm_eps": 1e-06,
39
+ "num_key_value_heads": 55,
40
+ "padded_vocab_size": 65536,
41
+ "padding_multiple": 4096,
42
+ "qk_bias": true,
43
+ "rope_base": 50000,
44
+ "sampling_scheme": "poisson-lognormal-filling",
45
+ "state_init": "like-init",
46
+ "tie_embeddings": true,
47
+ "torch_dtype": "float32",
48
+ "transformers_version": "4.46.3",
49
+ "vocab_size": 65536
50
+ }
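
Several of the fields above are derived rather than independent; the identities below (mirroring the "takase"-style derivation in `raven_config_minimal.py` further down) reproduce `head_dim`, `effective_expected_depth`, and the `init_values` block up to printed precision.

```python
# Re-derive the dependent fields of config.json (values match up to float printing).
from math import sqrt

n_embd, n_heads = 5280, 55
prelude, core, coda, mean_recurrence = 2, 4, 2, 32

print(n_embd // n_heads)                           # 96       -> head_dim
depth = prelude + coda + core * mean_recurrence
print(depth)                                       # 132      -> effective_expected_depth
print(sqrt(n_embd))                                # 72.6636  -> init_values.embed_scale
print(sqrt(2 / (5 * n_embd)))                      # 0.008704 -> init_values.std / embedding
print(sqrt(2 / (5 * n_embd)) / sqrt(2 * depth))    # 0.000536 -> init_values.out_proj
```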
generation_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.46.3"
4
+ }
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a0111d4cfdf344d4ae1ecb4a6630e62765ad35f3a2f12c0a2623e0607a7fe72
3
+ size 4771970936
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd1a99baee747fc18ce06f6bde9680f61073e6e1663b89838433cbe38e3a5f39
3
+ size 4744780096
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8010845abce6777bc55299da83a310e546ce000a603ac22e2c6922f4f067713f
3
+ size 4744737616
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6231039c098dd83b7646c857f47107d7573aa94b6091ca1ce1436fdbd3720827
3
+ size 1384120448
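
The four `.safetensors` entries above are Git LFS pointer files (an `oid` plus a `size`), not the tensors themselves. A hedged sketch of fetching the actual shards with `huggingface_hub`; the repository id is a placeholder.

```python
# Sketch: download the weight shards referenced by the LFS pointers above.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "<namespace>/<repo-id>",                        # placeholder repository id
    allow_patterns=["*.safetensors", "*.json", "*.py"],
)
print(local_dir)
```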
model.safetensors.index.json ADDED
@@ -0,0 +1,84 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 15645600384
4
+ },
5
+ "weight_map": {
6
+ "freqs_cis": "model-00001-of-00004.safetensors",
7
+ "lm_head.weight": "model-00004-of-00004.safetensors",
8
+ "transformer.adapter.weight": "model-00001-of-00004.safetensors",
9
+ "transformer.coda.0.attn.Wqkv.weight": "model-00003-of-00004.safetensors",
10
+ "transformer.coda.0.attn.proj.weight": "model-00003-of-00004.safetensors",
11
+ "transformer.coda.0.attn.qk_bias": "model-00003-of-00004.safetensors",
12
+ "transformer.coda.0.mlp.fc.weight": "model-00003-of-00004.safetensors",
13
+ "transformer.coda.0.mlp.proj.weight": "model-00003-of-00004.safetensors",
14
+ "transformer.coda.0.norm_1.weight": "model-00003-of-00004.safetensors",
15
+ "transformer.coda.0.norm_2.weight": "model-00003-of-00004.safetensors",
16
+ "transformer.coda.0.norm_3.weight": "model-00003-of-00004.safetensors",
17
+ "transformer.coda.0.norm_4.weight": "model-00003-of-00004.safetensors",
18
+ "transformer.coda.1.attn.Wqkv.weight": "model-00003-of-00004.safetensors",
19
+ "transformer.coda.1.attn.proj.weight": "model-00003-of-00004.safetensors",
20
+ "transformer.coda.1.attn.qk_bias": "model-00003-of-00004.safetensors",
21
+ "transformer.coda.1.mlp.fc.weight": "model-00003-of-00004.safetensors",
22
+ "transformer.coda.1.mlp.proj.weight": "model-00003-of-00004.safetensors",
23
+ "transformer.coda.1.norm_1.weight": "model-00003-of-00004.safetensors",
24
+ "transformer.coda.1.norm_2.weight": "model-00003-of-00004.safetensors",
25
+ "transformer.coda.1.norm_3.weight": "model-00003-of-00004.safetensors",
26
+ "transformer.coda.1.norm_4.weight": "model-00003-of-00004.safetensors",
27
+ "transformer.core_block.0.attn.Wqkv.weight": "model-00002-of-00004.safetensors",
28
+ "transformer.core_block.0.attn.proj.weight": "model-00002-of-00004.safetensors",
29
+ "transformer.core_block.0.attn.qk_bias": "model-00001-of-00004.safetensors",
30
+ "transformer.core_block.0.mlp.fc.weight": "model-00002-of-00004.safetensors",
31
+ "transformer.core_block.0.mlp.proj.weight": "model-00002-of-00004.safetensors",
32
+ "transformer.core_block.0.norm_1.weight": "model-00001-of-00004.safetensors",
33
+ "transformer.core_block.0.norm_2.weight": "model-00002-of-00004.safetensors",
34
+ "transformer.core_block.0.norm_3.weight": "model-00002-of-00004.safetensors",
35
+ "transformer.core_block.0.norm_4.weight": "model-00002-of-00004.safetensors",
36
+ "transformer.core_block.1.attn.Wqkv.weight": "model-00002-of-00004.safetensors",
37
+ "transformer.core_block.1.attn.proj.weight": "model-00002-of-00004.safetensors",
38
+ "transformer.core_block.1.attn.qk_bias": "model-00002-of-00004.safetensors",
39
+ "transformer.core_block.1.mlp.fc.weight": "model-00002-of-00004.safetensors",
40
+ "transformer.core_block.1.mlp.proj.weight": "model-00002-of-00004.safetensors",
41
+ "transformer.core_block.1.norm_1.weight": "model-00002-of-00004.safetensors",
42
+ "transformer.core_block.1.norm_2.weight": "model-00002-of-00004.safetensors",
43
+ "transformer.core_block.1.norm_3.weight": "model-00002-of-00004.safetensors",
44
+ "transformer.core_block.1.norm_4.weight": "model-00002-of-00004.safetensors",
45
+ "transformer.core_block.2.attn.Wqkv.weight": "model-00002-of-00004.safetensors",
46
+ "transformer.core_block.2.attn.proj.weight": "model-00002-of-00004.safetensors",
47
+ "transformer.core_block.2.attn.qk_bias": "model-00002-of-00004.safetensors",
48
+ "transformer.core_block.2.mlp.fc.weight": "model-00002-of-00004.safetensors",
49
+ "transformer.core_block.2.mlp.proj.weight": "model-00002-of-00004.safetensors",
50
+ "transformer.core_block.2.norm_1.weight": "model-00002-of-00004.safetensors",
51
+ "transformer.core_block.2.norm_2.weight": "model-00002-of-00004.safetensors",
52
+ "transformer.core_block.2.norm_3.weight": "model-00002-of-00004.safetensors",
53
+ "transformer.core_block.2.norm_4.weight": "model-00002-of-00004.safetensors",
54
+ "transformer.core_block.3.attn.Wqkv.weight": "model-00003-of-00004.safetensors",
55
+ "transformer.core_block.3.attn.proj.weight": "model-00003-of-00004.safetensors",
56
+ "transformer.core_block.3.attn.qk_bias": "model-00002-of-00004.safetensors",
57
+ "transformer.core_block.3.mlp.fc.weight": "model-00003-of-00004.safetensors",
58
+ "transformer.core_block.3.mlp.proj.weight": "model-00003-of-00004.safetensors",
59
+ "transformer.core_block.3.norm_1.weight": "model-00002-of-00004.safetensors",
60
+ "transformer.core_block.3.norm_2.weight": "model-00003-of-00004.safetensors",
61
+ "transformer.core_block.3.norm_3.weight": "model-00003-of-00004.safetensors",
62
+ "transformer.core_block.3.norm_4.weight": "model-00003-of-00004.safetensors",
63
+ "transformer.ln_f.weight": "model-00003-of-00004.safetensors",
64
+ "transformer.prelude.0.attn.Wqkv.weight": "model-00001-of-00004.safetensors",
65
+ "transformer.prelude.0.attn.proj.weight": "model-00001-of-00004.safetensors",
66
+ "transformer.prelude.0.attn.qk_bias": "model-00001-of-00004.safetensors",
67
+ "transformer.prelude.0.mlp.fc.weight": "model-00001-of-00004.safetensors",
68
+ "transformer.prelude.0.mlp.proj.weight": "model-00001-of-00004.safetensors",
69
+ "transformer.prelude.0.norm_1.weight": "model-00001-of-00004.safetensors",
70
+ "transformer.prelude.0.norm_2.weight": "model-00001-of-00004.safetensors",
71
+ "transformer.prelude.0.norm_3.weight": "model-00001-of-00004.safetensors",
72
+ "transformer.prelude.0.norm_4.weight": "model-00001-of-00004.safetensors",
73
+ "transformer.prelude.1.attn.Wqkv.weight": "model-00001-of-00004.safetensors",
74
+ "transformer.prelude.1.attn.proj.weight": "model-00001-of-00004.safetensors",
75
+ "transformer.prelude.1.attn.qk_bias": "model-00001-of-00004.safetensors",
76
+ "transformer.prelude.1.mlp.fc.weight": "model-00001-of-00004.safetensors",
77
+ "transformer.prelude.1.mlp.proj.weight": "model-00001-of-00004.safetensors",
78
+ "transformer.prelude.1.norm_1.weight": "model-00001-of-00004.safetensors",
79
+ "transformer.prelude.1.norm_2.weight": "model-00001-of-00004.safetensors",
80
+ "transformer.prelude.1.norm_3.weight": "model-00001-of-00004.safetensors",
81
+ "transformer.prelude.1.norm_4.weight": "model-00001-of-00004.safetensors",
82
+ "transformer.wte.weight": "model-00001-of-00004.safetensors"
83
+ }
84
+ }
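
Roughly, the index places the embedding, adapter, and prelude on shard 1, the recurrent core mostly on shards 2-3, the coda on shard 3, and `lm_head.weight` on shard 4; the small sketch below recovers that grouping (and the ~14.6 GiB of tensor data at fp32) from the file.

```python
# Sketch: group the entries of model.safetensors.index.json by shard.
import json
from collections import defaultdict

with open("model.safetensors.index.json") as f:
    index = json.load(f)

by_shard = defaultdict(list)
for tensor_name, shard_file in index["weight_map"].items():
    by_shard[shard_file].append(tensor_name)

for shard_file in sorted(by_shard):
    print(shard_file, len(by_shard[shard_file]), "tensors")
print(index["metadata"]["total_size"] / 2**30, "GiB of tensor data")  # ~14.57 GiB
```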
raven_config_minimal.py ADDED
@@ -0,0 +1,96 @@
1
+ """A HuggingFace-style model configuration."""
2
+
3
+ from transformers import PretrainedConfig
4
+ from math import sqrt
5
+
6
+
7
+ class RavenConfig(PretrainedConfig):
8
+ model_type = "huginn_raven"
9
+ keys_to_ignore_at_inference = [""]
10
+ attribute_map = {"num_attention_heads": "n_heads", "hidden_size": "n_embd", "num_hidden_layers": "n_layers"}
11
+
12
+ def __init__(
13
+ self,
14
+ n_embd: int = 5280,
15
+ n_heads: int = 55,
16
+ n_layers: int = 8, # total of prelude + recurrent + coda
17
+ block_size: int = 4096,
18
+ vocab_size: int = 65536,
19
+ padding_multiple: int = 4096,
20
+ tie_embeddings: bool = True,
21
+ intermediate_size: int = 17920,
22
+ bias: bool = False,
23
+ architecture_class_name: str = "RecurrentGPT",
24
+ block_class_name: str = "SandwichBlock",
25
+ norm_class_name: str = "RMSNorm_llama",
26
+ norm_eps: float = 0.000001,
27
+ mlp_class_name: str = "GatedMLP",
28
+ nonlin_name: str = "SiLU",
29
+ init_strategy: str = "takase",
30
+ init_orthogonal: bool = False,
31
+ state_init: str = "like-init",
32
+ injection_type: str = "linear",
33
+ n_layers_in_recurrent_block: int = 4,
34
+ mean_recurrence: int = 32,
35
+ sampling_scheme: str = "poisson-lognormal-filling",
36
+ mean_backprop_depth: int = 8,
37
+ n_layers_in_prelude: int = 2,
38
+ n_layers_in_coda: int = 2,
39
+ qk_bias: bool = True,
40
+ activation_checkpoint_impl: str = "per-iteration",
41
+ rope_base: float = 50_000,
42
+ torch_dtype: str = "bfloat16",
43
+ transformers_version: str = "4.47.1",
44
+ **kwargs,
45
+ ):
46
+ self.n_embd = n_embd
47
+ self.n_heads = n_heads
48
+ self.n_layers = n_layers
49
+ self.block_size = block_size
50
+ self.vocab_size = self.padded_vocab_size = vocab_size
51
+ self.padding_multiple = padding_multiple
52
+ self.tie_embeddings = tie_embeddings
53
+ self.intermediate_size = intermediate_size
54
+ self.bias = bias
55
+ self.architecture_class_name = architecture_class_name
56
+ self.block_class_name = block_class_name
57
+ self.norm_class_name = norm_class_name
58
+ self.norm_eps = norm_eps
59
+ self.mlp_class_name = mlp_class_name
60
+ self.nonlin_name = nonlin_name
61
+ self.init_strategy = init_strategy
62
+ self.init_orthogonal = init_orthogonal
63
+ self.state_init = state_init
64
+ self.injection_type = injection_type
65
+ self.n_layers_in_recurrent_block = n_layers_in_recurrent_block
66
+ self.mean_recurrence = mean_recurrence
67
+ self.sampling_scheme = sampling_scheme
68
+ self.mean_backprop_depth = mean_backprop_depth
69
+ self.n_layers_in_prelude = n_layers_in_prelude
70
+ self.n_layers_in_coda = n_layers_in_coda
71
+ self.qk_bias = qk_bias
72
+ self.activation_checkpoint_impl = activation_checkpoint_impl
73
+ self.rope_base = rope_base
74
+ self.torch_dtype = torch_dtype # Added from JSON
75
+ self.transformers_version = transformers_version # Added from JSON
76
+ # Derived
77
+ self.num_key_value_heads = n_heads
78
+ self.num_attention_heads = n_heads
79
+ self.head_dim = n_embd // n_heads
80
+ self.effective_expected_depth = (
81
+ self.n_layers_in_prelude + self.n_layers_in_coda + self.n_layers_in_recurrent_block * self.mean_recurrence
82
+ )
83
+ self.init_values = {
84
+ "std": sqrt(2 / (5 * self.n_embd)),
85
+ "out_proj": sqrt(2 / (5 * self.n_embd)) / sqrt(2 * self.effective_expected_depth),
86
+ "embedding": sqrt(2 / (5 * self.n_embd)),
87
+ "embed_scale": sqrt(self.n_embd),
88
+ }
89
+
90
+ super().__init__(
91
+ # pad_token_id=65509,
92
+ # bos_token_id=65504,
93
+ # eos_token_id=65505,
94
+ tie_word_embeddings=tie_embeddings,
95
+ **kwargs,
96
+ )
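
A quick way to sanity-check the defaults in this config class against `config.json` above is to instantiate it directly; this assumes `raven_config_minimal.py` is importable as a top-level module (when loaded through `auto_map` it is instead imported as part of the repository package).

```python
# Sketch: instantiate RavenConfig with its defaults and inspect the derived attributes.
from raven_config_minimal import RavenConfig  # assumes the file is on the Python path

cfg = RavenConfig()
print(cfg.n_layers_in_prelude, cfg.n_layers_in_recurrent_block, cfg.n_layers_in_coda)  # 2 4 2
print(cfg.head_dim)                              # 96
print(cfg.effective_expected_depth)              # 132
print(cfg.num_attention_heads)                   # 55, resolved to n_heads via attribute_map
print(round(cfg.init_values["embed_scale"], 4))  # 72.6636, as stored in config.json
```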
raven_modeling_minimal.py ADDED
@@ -0,0 +1,909 @@
1
+ """Minimal modeling.py file for HF compatibility and funny zero-shot experiments. Use only for inference."""
2
+
3
+ import torch
4
+ import math
5
+
6
+ from torch import Tensor
7
+ from dataclasses import dataclass
8
+ from typing import Optional, Union, Any
9
+
10
+ from .raven_config_minimal import RavenConfig
11
+ from transformers.cache_utils import Cache, DynamicCache
12
+
13
+ ###################### Huggingface Glue code I ##################################################################
14
+ from transformers import PreTrainedModel
15
+ from transformers.utils import ModelOutput
16
+ from transformers.generation.utils import GenerateDecoderOnlyOutput
17
+
18
+ import torch.nn.functional as F
19
+ from transformers import GenerationConfig
20
+
21
+
22
+ class RavenPreTrainedModel(PreTrainedModel):
23
+ config_class = RavenConfig
24
+ base_model_prefix = "model"
25
+ supports_gradient_checkpointing = True
26
+ _no_split_modules = ["SandwichBlock"]
27
+ _skip_keys_device_placement = ["past_key_values"]
28
+ _supports_flash_attn_2 = True
29
+ _supports_sdpa = True
30
+ _supports_cache_class = True
31
+ _supports_quantized_cache = False
32
+ _supports_static_cache = False
33
+
34
+ def _init_weights(self, module):
35
+ print("Random Initialization not implemented.")
36
+
37
+
38
+ @dataclass
39
+ class CausalLMOutputRecurrentLatents(ModelOutput):
40
+ loss: Optional[torch.Tensor] = None
41
+ log_ppl: Optional[torch.Tensor] = None
42
+ logits: Optional[torch.Tensor] = None
43
+ past_key_values: Optional[Cache] = None
44
+ latent_states: Optional[torch.Tensor] = None
45
+ hidden_states: Optional[torch.Tensor] = None
46
+ attention_maps: Optional[dict[int, torch.Tensor]] = None
47
+ stats: Optional[dict] = None
48
+
49
+
50
+ ###################### Minimal implementation from here ############################################################
51
+
52
+
53
+ class RMSNorm(torch.nn.Module):
54
+ """Saner dtype handling and slightly better for fusion"""
55
+
56
+ def __init__(self, dim: int, eps: float = 1e-6):
57
+ super().__init__()
58
+ self.eps = eps
59
+ self.weight = torch.nn.Parameter(torch.ones(dim))
60
+
61
+ def _norm(self, x):
62
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
63
+
64
+ def forward(self, x):
65
+ with torch.autocast(enabled=False, device_type=x.device.type):
66
+ return self._norm(x.float()).type_as(x) * self.weight
67
+
68
+ def reset_parameters(self) -> None:
69
+ torch.nn.init.ones_(self.weight)
70
+
71
+
72
+ class HuginnDynamicCache(DynamicCache):
73
+ def __init__(self, lookup_strategy: str = "latest") -> None:
74
+ super().__init__()
75
+ self._seen_tokens = 0
76
+ self.key_cache: dict[int, dict[int, torch.Tensor]] = {}
77
+ self.value_cache: dict[int, dict[int, torch.Tensor]] = {}
78
+ # structure: cache[index_of_layer_or_recurrent_step][index_in_sequence]
79
+ # the cache is held uncoalesced because certain recurrent steps may be missing for some sequence ids if using
80
+ # per-token adaptive compute. In those cases, the "lookup_strategy" determines how to proceed
81
+ # Also, It is critical that the head indices do not overlap with the recurrent iteration indices
82
+ self.lookup_strategy = lookup_strategy
83
+
84
+ def update(
85
+ self,
86
+ key_states: torch.Tensor,
87
+ value_states: torch.Tensor,
88
+ step_idx: int,
89
+ lookup_strategy: Optional[str] = None,
90
+ ) -> tuple[torch.Tensor, torch.Tensor]:
91
+ lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy
92
+ # Init
93
+ if step_idx not in self.key_cache:
94
+ self.key_cache[step_idx] = {}
95
+ self.value_cache[step_idx] = {}
96
+ # Update the number of seen tokens, we assume that step_idx=0 (first prelude) is always hit
97
+ if step_idx == 0:
98
+ self._seen_tokens += key_states.shape[-2]
99
+ # Add entries to cache
100
+ for idx, entry in enumerate(key_states.unbind(dim=-2)):
101
+ assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx]
102
+ # print(f"Overwrote cache entry for step_idx {step_idx}") # likely the head
103
+ self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry
104
+ for idx, entry in enumerate(value_states.unbind(dim=-2)):
105
+ self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry
106
+
107
+ # Materialize past state based on lookup strategy:
108
+ if len(self.key_cache[step_idx]) == self._seen_tokens:
109
+ # All entries are present, materialize cache as normal
110
+ return (
111
+ torch.stack(list(self.key_cache[step_idx].values()), dim=-2),
112
+ torch.stack(list(self.value_cache[step_idx].values()), dim=-2),
113
+ )
114
+ else: # some entries where not previously computed
115
+ if lookup_strategy == "latest":
116
+ latest_keys = []
117
+ latest_values = []
118
+ for token_pos in range(self._seen_tokens):
119
+ # Find the latest step that has this token position
120
+ max_step = max((s for s in range(step_idx + 1) if token_pos in self.key_cache[s]), default=None)
121
+ if max_step is None:
122
+ raise ValueError(f"No cache entry found for token position {token_pos}")
123
+ latest_keys.append(self.key_cache[max_step][token_pos])
124
+ latest_values.append(self.value_cache[max_step][token_pos])
125
+ return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2)
126
+ elif lookup_strategy == "skip":
127
+ existing_keys = []
128
+ existing_values = []
129
+ for token_pos in range(self._seen_tokens):
130
+ if token_pos in self.key_cache[step_idx]:
131
+ existing_keys.append(self.key_cache[step_idx][token_pos])
132
+ existing_values.append(self.value_cache[step_idx][token_pos])
133
+ return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2)
134
+ elif lookup_strategy == "randomized": # sanity check
135
+ rand_keys = []
136
+ rand_values = []
137
+ for token_pos in range(self._seen_tokens):
138
+ # Find steps that have this token position
139
+ steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]]
140
+ rand_step = steps[torch.randint(len(steps), (1,))]
141
+ rand_keys.append(self.key_cache[rand_step][token_pos])
142
+ rand_values.append(self.value_cache[rand_step][token_pos])
143
+ return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2)
144
+ else:
145
+ raise ValueError(f"Unknown lookup strategy: {lookup_strategy}")
146
+
147
+ def reset(self) -> None:
148
+ """Reset the cache state."""
149
+ self._seen_tokens = 0
150
+ self.key_cache.clear()
151
+ self.value_cache.clear()
152
+
153
+ def get_seq_length(self, step_idx: int = 0) -> int:
154
+ return self._seen_tokens
155
+
156
+
157
+ class CausalSelfAttention(torch.nn.Module):
158
+ def __init__(self, config: RavenConfig) -> None:
159
+ super().__init__()
160
+ self.config = config
161
+ self.n_head = config.num_attention_heads
162
+ self.n_kv_heads = config.num_key_value_heads
163
+ self.head_dim = config.n_embd // self.n_head
164
+
165
+ shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim
166
+ self.chunks = [config.n_embd, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim]
167
+ self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False)
168
+ if config.qk_bias:
169
+ self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim))
170
+ self.proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=False)
171
+
172
+ def forward(
173
+ self,
174
+ x: Tensor,
175
+ freqs_cis: Tensor,
176
+ step_idx: int,
177
+ mask: Optional[Tensor] = None,
178
+ past_key_values: Optional[Cache] = None,
179
+ return_attn: bool = False,
180
+ ) -> tuple[Tensor, Optional[Tensor]]:
181
+ B, S, E = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
182
+ q, k, v = self.Wqkv(x).split(self.chunks, dim=2)
183
+ q = q.view(B, S, self.n_head, self.head_dim)
184
+ k = k.view(B, S, self.n_kv_heads, self.head_dim)
185
+ v = v.view(B, S, self.n_kv_heads, self.head_dim)
186
+ # bias?
187
+ if self.config.qk_bias:
188
+ q_bias, k_bias = self.qk_bias.split(1, dim=0)
189
+ q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype)
190
+ # apply rotary
191
+ q, k = apply_rotary_emb_complex_like(q, k, freqs_cis=freqs_cis)
192
+
193
+ q = q.transpose(1, 2) # (B, nh, S, hs)
194
+ k = k.transpose(1, 2)
195
+ v = v.transpose(1, 2)
196
+
197
+ if past_key_values is not None:
198
+ k, v = past_key_values.update(k, v, step_idx)
199
+
200
+ if return_attn:
201
+ y, attention_map = self.compute_eager_sdpa(q, k, v, attn_mask=mask)
202
+ else:
203
+ y = torch.nn.functional.scaled_dot_product_attention(
204
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=q.shape[2] > 1
205
+ )
206
+ y = y.transpose(1, 2).reshape(B, S, E).contiguous() # reshape is a view if possible (it mostly is)
207
+ return self.proj(y), attention_map if return_attn else None
208
+
209
+ def compute_eager_sdpa(self, q, k, v, attn_mask):
210
+ scale = 1.0 / math.sqrt(self.head_dim)
211
+ scores = torch.matmul(q, k.transpose(-2, -1)) * scale
212
+
213
+ if attn_mask is not None:
214
+ scores = scores + attn_mask
215
+ if q.shape[2] > 1:
216
+ causal_mask = torch.triu(torch.ones(q.shape[2], q.shape[2]), diagonal=1).bool()
217
+ scores.masked_fill_(causal_mask.to(scores.device), float("-inf"))
218
+
219
+ attention_weights = torch.nn.functional.softmax(scores, dim=-1)
220
+ y = torch.matmul(attention_weights, v)
221
+ return y, attention_weights.max(dim=1)[0]
222
+
223
+
224
+ class GatedMLP(torch.nn.Module):
225
+ def __init__(self, config: RavenConfig, in_features: int = 0) -> None:
226
+ super().__init__()
227
+ in_features = config.n_embd if in_features == 0 else in_features
228
+ self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False)
229
+
230
+ self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False)
231
+ self.nonlin = torch.nn.SiLU()
232
+
233
+ def forward(self, x: Tensor) -> Tensor:
234
+ # modified to single FC layer to improve parallelism
235
+ x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1)
236
+ x = self.nonlin(x_fc_1) * x_fc_2
237
+ return self.proj(x)
238
+
239
+
240
+ class SandwichBlock(torch.nn.Module):
241
+ expanded = False
242
+
243
+ def __init__(self, config: RavenConfig, layer_id: int) -> None:
244
+ super().__init__()
245
+ self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
246
+ self.attn = CausalSelfAttention(config)
247
+ self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
248
+ self.mlp = GatedMLP(config)
249
+ self.norm_3 = RMSNorm(config.n_embd, eps=config.norm_eps)
250
+ self.norm_4 = RMSNorm(config.n_embd, eps=config.norm_eps)
251
+ self.layer_id = layer_id
252
+
253
+ def forward(
254
+ self,
255
+ x: Tensor,
256
+ freqs_cis: Tensor,
257
+ step_idx: int,
258
+ mask: Optional[Tensor] = None,
259
+ past_key_values: Optional[Cache] = None,
260
+ return_attn: bool = False,
261
+ ) -> tuple[Tensor, Optional[Tensor]]:
262
+ attn_out, attn_map = self.attn(self.norm_1(x), freqs_cis, step_idx, mask, past_key_values, return_attn)
263
+ x = self.norm_2(attn_out + x)
264
+ x = self.norm_4(self.mlp(self.norm_3(x)) + x)
265
+ return x, attn_map
266
+
267
+
268
+ class RavenForCausalLM(RavenPreTrainedModel):
269
+ def __init__(
270
+ self,
271
+ config: RavenConfig,
272
+ ) -> None:
273
+ super().__init__(config)
274
+ self.config = config
275
+
276
+ # Transformer layers
277
+ prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude))
278
+ adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias)
279
+ core_block = torch.nn.ModuleList(
280
+ SandwichBlock(config, layer_id=i + config.n_layers_in_prelude)
281
+ for i in range(config.n_layers_in_recurrent_block)
282
+ )
283
+ o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence
284
+ coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda))
285
+
286
+ self.transformer = torch.nn.ModuleDict(
287
+ dict(
288
+ wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd),
289
+ prelude=prelude,
290
+ adapter=adapter,
291
+ core_block=core_block,
292
+ coda=coda,
293
+ ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), # used twice :>
294
+ )
295
+ )
296
+ self.emb_scale = config.init_values["embed_scale"]
297
+ # Head
298
+ self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
299
+ if self.config.tie_embeddings:
300
+ self.lm_head.weight = self.transformer.wte.weight
301
+ # rope
302
+ self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
303
+
304
+ def _precompute_freqs_cis(self):
305
+ # can actually be a buffer now, and remains in fp32! (at least in the settings I tested)
306
+ freqs_cis = precompute_freqs_cis(
307
+ self.config.n_embd // self.config.num_attention_heads, self.config.block_size, self.config.rope_base, 1
308
+ )
309
+ return freqs_cis
310
+
311
+ def forward(
312
+ self,
313
+ input_ids: torch.Tensor,
314
+ input_embeds: Optional[torch.Tensor] = None,
315
+ input_states: Optional[torch.Tensor] = None,
316
+ attention_mask: Optional[torch.Tensor] = None,
317
+ position_ids: Optional[torch.Tensor] = None,
318
+ labels: Optional[torch.Tensor] = None,
319
+ num_steps: Optional[torch.Tensor] = None,
320
+ past_key_values: Optional[Cache] = None,
321
+ output_details: dict = {
322
+ "return_logits": True,
323
+ "return_latents": True,
324
+ "return_attention": False,
325
+ "return_head": False,
326
+ "return_stats": True,
327
+ },
328
+ use_cache: bool = False,
329
+ cache_position: Optional[torch.Tensor] = None,
330
+ **kwargs,
331
+ ) -> CausalLMOutputRecurrentLatents:
332
+ # Support multiple position formats:
333
+ if position_ids is None and cache_position is None:
334
+ freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
335
+ elif position_ids is not None:
336
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
337
+ elif cache_position is not None:
338
+ freqs_cis = self.freqs_cis[:, cache_position]
339
+
340
+ if input_embeds is None:
341
+ input_embeds = self.transformer.wte(input_ids)
342
+
343
+ if self.emb_scale != 1:
344
+ input_embeds = input_embeds * self.emb_scale # type: ignore
345
+
346
+ if use_cache and past_key_values is None:
347
+ past_key_values = HuginnDynamicCache()
348
+ attn_maps = {}
349
+ return_attn = output_details["return_attention"]
350
+
351
+ # Non-recurrent prelude
352
+ for block_idx, block in enumerate(self.transformer.prelude):
353
+ input_embeds, attn_map = block(
354
+ input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
355
+ )
356
+ attn_maps[block_idx] = attn_map
357
+
358
+ # Main recurrence
359
+ x, num_steps_no_grad, num_steps_with_grad, xk, block_idx, attn_maps = self.iterate_forward(
360
+ input_embeds, # type: ignore
361
+ input_states,
362
+ freqs_cis,
363
+ block_idx,
364
+ attention_mask,
365
+ past_key_values,
366
+ num_steps,
367
+ attn_maps,
368
+ )
369
+ latent_states = x.clone().detach()
370
+
371
+ # Coda layers
372
+ for block_idx, block in enumerate(self.transformer.coda, start=1):
373
+ x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values, return_attn)
374
+ attn_maps[-block_idx] = attn_map
375
+ x = self.transformer.ln_f(x)
376
+
377
+ # Prediction head, assuming labels really are labels and not equal to input_ids
378
+ if labels is not None:
379
+ logits = self.lm_head(x).float()
380
+ loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.shape[-1]), labels.view(-1))
381
+ log_ppl = loss.clone().detach()
382
+ else:
383
+ logits = self.lm_head(x).float()
384
+ loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0)
385
+
386
+ return CausalLMOutputRecurrentLatents(
387
+ loss=loss,
388
+ log_ppl=log_ppl,
389
+ logits=logits if output_details["return_logits"] else None,
390
+ past_key_values=past_key_values,
391
+ hidden_states=x if output_details["return_head"] else None,
392
+ latent_states=latent_states if output_details["return_latents"] else None,
393
+ attention_maps=attn_maps if output_details["return_attention"] else None, # type: ignore
394
+ stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad)
395
+ if output_details["return_stats"]
396
+ else None,
397
+ )
398
+
399
+ @torch._dynamo.disable(recursive=False) # type: ignore
400
+ def iterate_forward(
401
+ self,
402
+ input_embeds,
403
+ input_states,
404
+ freqs_cis,
405
+ block_idx,
406
+ mask,
407
+ past_key_values: Optional[Cache] = None,
408
+ num_steps: Optional[torch.Tensor] = None,
409
+ attn_maps: dict = {},
410
+ ):
411
+ x = xk = self.initialize_state(input_embeds) if input_states is None else input_states.clone()
412
+ if num_steps is None:
413
+ num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() # type: ignore
414
+ elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
415
+ num_steps_no_grad, num_steps_with_grad = num_steps
416
+ else:
417
+ num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
418
+
419
+ with torch.no_grad():
420
+ # ultra annoying in ddp due to
421
+ # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
422
+ # for now running with find_unused_params=True enabled even though the graph structure is (technically) clear
423
+ # and all parameters are always used
424
+ for step in range(num_steps_no_grad):
425
+ xk = x
426
+ x, block_idx, attn_maps = self.core_block_forward(
427
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
428
+ )
429
+
430
+ for step in range(num_steps_with_grad):
431
+ xk = x
432
+ x, block_idx, attn_maps = self.core_block_forward(
433
+ xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, attn_maps
434
+ )
435
+ return self.transformer.ln_f(x), num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx, attn_maps
436
+
437
+ def core_block_forward(
438
+ self,
439
+ x,
440
+ input_embeds,
441
+ freqs_cis,
442
+ mask,
443
+ past_key_values,
444
+ block_idx: Union[torch.Tensor, int],
445
+ attn_maps: dict = {},
446
+ ):
447
+ x = self.transformer.adapter(torch.cat([x, input_embeds], dim=-1))
448
+ for idx, block in enumerate(self.transformer.core_block, start=1):
449
+ x, attn_map = block(x, freqs_cis, block_idx + idx, mask, past_key_values, return_attn=len(attn_maps) > 0)
450
+ attn_maps[block_idx + idx] = attn_map
451
+ return x, block_idx + idx, attn_maps
452
+
453
+ @torch.no_grad()
454
+ def iterate_one_step(
455
+ self,
456
+ input_embeds,
457
+ input_states,
458
+ position_ids: Optional[torch.Tensor] = None,
459
+ cache_position: Optional[torch.Tensor] = None,
460
+ block_idx: Union[torch.Tensor, int] = 0,
461
+ attention_mask: Optional[Tensor] = None,
462
+ past_key_values: Optional[Cache] = None,
463
+ attn_maps: dict = {},
464
+ ):
465
+ if position_ids is None and cache_position is None:
466
+ freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]]
467
+ elif position_ids is not None:
468
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
469
+ elif cache_position is not None:
470
+ freqs_cis = self.freqs_cis[:, cache_position]
471
+ x, block_idx, attn_maps = self.core_block_forward(
472
+ input_states, input_embeds, freqs_cis, attention_mask, past_key_values, block_idx, attn_maps
473
+ )
474
+ return x, block_idx, attn_maps
475
+
476
+ def predict_from_latents(
477
+ self,
478
+ latents,
479
+ attention_mask: Optional[torch.Tensor] = None,
480
+ position_ids: Optional[torch.Tensor] = None,
481
+ cache_position: Optional[torch.Tensor] = None,
482
+ past_key_values: Optional[Cache] = None,
483
+ return_attn: bool = False,
484
+ attn_maps: dict = {},
485
+ ):
486
+ if position_ids is None and cache_position is None:
487
+ freqs_cis = self.freqs_cis[:, : latents.shape[1]]
488
+ elif position_ids is not None:
489
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
490
+ elif cache_position is not None:
491
+ freqs_cis = self.freqs_cis[:, cache_position]
492
+ x = self.transformer.ln_f(latents)
493
+ # Coda layers
494
+ for block_idx, block in enumerate(self.transformer.coda, start=1):
495
+ x, attn_map = block(x, freqs_cis, -block_idx, attention_mask, past_key_values)
496
+ attn_maps[block_idx] = attn_map
497
+ x = self.transformer.ln_f(x)
498
+
499
+ logits = self.lm_head(x).float()
500
+
501
+ return CausalLMOutputRecurrentLatents(
502
+ loss=torch.as_tensor(0.0),
503
+ log_ppl=torch.as_tensor(0.0),
504
+ logits=logits,
505
+ past_key_values=past_key_values,
506
+ attention_maps=attn_maps if len(attn_maps) > 0 else None,
507
+ )
508
+
509
+ def embed_inputs(
510
+ self,
511
+ input_ids: torch.Tensor,
512
+ attention_mask: Optional[torch.Tensor] = None,
513
+ position_ids: Optional[torch.Tensor] = None,
514
+ past_key_values: Optional[Cache] = None,
515
+ use_cache: bool = False,
516
+ cache_position: Optional[torch.Tensor] = None,
517
+ return_attn: bool = False,
518
+ **kwargs,
519
+ ) -> tuple[torch.Tensor, int, dict[int, Tensor]]:
520
+ # Support multiple position formats:
521
+ if position_ids is None and cache_position is None:
522
+ freqs_cis = self.freqs_cis[:, : input_ids.shape[1]]
523
+ elif position_ids is not None:
524
+ freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze())
525
+ elif cache_position is not None:
526
+ freqs_cis = self.freqs_cis[:, cache_position]
527
+
528
+ input_embeds = self.transformer.wte(input_ids)
529
+
530
+ if self.emb_scale != 1:
531
+ input_embeds = input_embeds * self.emb_scale # type: ignore
532
+
533
+ if use_cache and past_key_values is None:
534
+ past_key_values = HuginnDynamicCache()
535
+
536
+ # Non-recurrent prelude
537
+ attn_maps = {}
538
+ for block_idx, block in enumerate(self.transformer.prelude):
539
+ input_embeds, attn_maps = block(
540
+ input_embeds, freqs_cis, block_idx, attention_mask, past_key_values, return_attn
541
+ )
542
+ return input_embeds, block_idx, attn_maps
543
+
544
+ @torch._dynamo.disable(recursive=False) # type: ignore
545
+ def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
546
+ """Outputs are long tensors so that they can be passed through compiled functions"""
547
+ t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
548
+ s = self.config.mean_backprop_depth
549
+ if self.training:
550
+ sigma = 0.5
551
+ mu = math.log(t + s) - (sigma**2 / 2)
552
+ rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma)
553
+ p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1
554
+ n = torch.clamp(p - s, min=0)
555
+ k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
556
+ else:
557
+ n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0)
558
+
559
+ return n.to(dtype=torch.long), k.to(dtype=torch.long)
560
+
561
+ def initialize_state(self, input_embeds, deterministic: bool = False):
562
+ x = torch.randn_like(input_embeds)
563
+ std = self.config.init_values["std"]
564
+ torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std)
565
+ if self.emb_scale != 1:
566
+ x = x * self.emb_scale
567
+ return x if not deterministic else x.zero_()
568
+
569
+ def prepare_inputs_for_generation(
570
+ self,
571
+ input_ids: torch.LongTensor,
572
+ past_key_values: Optional[Cache] = None,
573
+ attention_mask: Optional[torch.LongTensor] = None,
574
+ inputs_embeds: Optional[torch.FloatTensor] = None,
575
+ cache_position: Optional[torch.LongTensor] = None,
576
+ **kwargs,
577
+ ):
578
+ model_inputs = {}
579
+ model_inputs["cache_position"] = cache_position
580
+ current_input_length = input_ids.shape[1]
581
+ if past_key_values is not None:
582
+ if type(past_key_values) == DynamicCache:
583
+ # Need to use custom cache, detect and replace HF dynamic cache if generate injects it
584
+ assert past_key_values.get_seq_length() == 0
585
+ past_key_values = HuginnDynamicCache()
586
+ model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None
587
+ input_ids = input_ids[:, cache_position] # type: ignore
588
+ model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format)
589
+
590
+ if cache_position is None:
591
+ position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device)
592
+ model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone(
593
+ memory_format=torch.contiguous_format
594
+ ) # some form of position_ids is a critical argument for the model to correctly apply rope!
595
+
596
+ # forward all other entries
597
+ for key, value in kwargs.items():
598
+ if key not in model_inputs:
599
+ model_inputs[key] = value
600
+ return model_inputs
601
+
602
+ @torch.no_grad()
603
+ def generate_minimal(
604
+ self,
605
+ input_ids: torch.LongTensor,
606
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
607
+ tokenizer=None,
608
+ streamer=None,
609
+ continuous_compute=False, # warm-start state / continuous CoT
610
+ cache_kwargs: dict = {},
611
+ **model_kwargs,
612
+ ) -> Union[torch.Tensor, dict[str, Any]]:
613
+ """Minimal single-sequence generation. Template for more complicated generate tasks"""
614
+ # Setup
615
+ if generation_config is None:
616
+ generation_config: GenerationConfig = self.generation_config # type: ignore
617
+ model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs)
618
+ model_kwargs["use_cache"] = True
619
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
620
+ stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
621
+ if continuous_compute:
622
+ embedded_inputs, _, _ = self.embed_inputs(input_ids)
623
+ model_kwargs["input_states"] = self.initialize_state(embedded_inputs)
624
+ # Generate tokens
625
+ for _ in range(generation_config.max_length - input_ids.shape[1]):
626
+ # Forward pass
627
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
628
+ outputs = self(**model_inputs)
629
+ next_token_logits = outputs.logits[0, -1, :]
630
+ if continuous_compute:
631
+ current_last_latent = outputs.latent_states[:, -1:, :]
632
+
633
+ # Sample or select next token
634
+ if generation_config.do_sample:
635
+ if generation_config.temperature:
636
+ next_token_logits = next_token_logits / generation_config.temperature
637
+
638
+ probs = F.softmax(next_token_logits, dim=-1)
639
+
640
+ # Apply top_k
641
+ if generation_config.top_k:
642
+ top_k_probs, _ = torch.topk(probs, generation_config.top_k)
643
+ probs[probs < top_k_probs[-1]] = 0
644
+ # Apply top_p
645
+ if generation_config.top_p:
646
+ sorted_probs = torch.sort(probs, descending=True)[0]
647
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
648
+ probs[cumsum > generation_config.top_p] = 0
649
+ # Apply min_p
650
+ if generation_config.min_p:
651
+ probs[probs < generation_config.min_p * probs.max()] = 0
652
+
653
+ probs = probs / probs.sum()
654
+ next_token = torch.multinomial(probs, num_samples=1)
655
+ else:
656
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
657
+
658
+ input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1) # type: ignore
659
+
660
+ if streamer:
661
+ streamer.put(next_token.cpu())
662
+
663
+ # Update model kwargs
664
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
665
+ if continuous_compute:
666
+ model_kwargs["input_states"] = current_last_latent
667
+
668
+ # Check if we hit a stop token
669
+ if stop_tokens is not None and next_token in stop_tokens:
670
+ break
671
+
672
+ if streamer:
673
+ streamer.end()
674
+
675
+ if generation_config.return_dict_in_generate:
676
+ return GenerateDecoderOnlyOutput(
677
+ sequences=input_ids,
678
+ scores=None,
679
+ logits=None,
680
+ attentions=None,
681
+ hidden_states=None,
682
+ past_key_values=model_kwargs.get("past_key_values"),
683
+ )
684
+ return input_ids
685
+
686
+ @torch.no_grad()
687
+ def generate_with_adaptive_compute(
688
+ self,
689
+ input_ids: torch.LongTensor,
690
+ generation_config: Optional[GenerationConfig] = None, # type: ignore
691
+ tokenizer=None,
692
+ streamer=None,
693
+ continuous_compute=False, # warm-start state / continuous CoT
694
+ latent_dampening=False,
695
+ criterion="entropy-diff",
696
+ cache_kwargs: dict = {},
697
+ **model_kwargs,
698
+ ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]:
699
+ """Minimal single-sequence generation. Template for more complicated generate tasks"""
700
+ # Setup
701
+ if generation_config is None:
702
+ generation_config: GenerationConfig = self.generation_config # type: ignore
703
+ model_kwargs["past_key_values"] = HuginnDynamicCache(**cache_kwargs)
704
+ model_kwargs["use_cache"] = True
705
+ model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
706
+ stop_tokens = self._get_stops(generation_config, tokenizer).to(input_ids.device)
707
+ if continuous_compute:
708
+ embedded_inputs, _, _ = self.embed_inputs(input_ids)
709
+ current_last_latent = self.initialize_state(embedded_inputs)
710
+ compute_steps = []
711
+
712
+ # Generate tokens
713
+ for step in range(generation_config.max_length - input_ids.shape[1]):
714
+ # Adaptive compute forward
715
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
716
+ aux_inputs = {
717
+ k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs
718
+ }
719
+ embedded_inputs, block_idx, _ = self.embed_inputs(model_inputs["input_ids"], **aux_inputs)
720
+ if not continuous_compute:
721
+ current_latents = self.initialize_state(embedded_inputs, deterministic=False)
722
+ else:
723
+ current_latents = current_last_latent
724
+
725
+ # Prep criterions:
726
+ if criterion == "entropy-diff":
727
+ entropy = torch.tensor(100.0, device=input_ids.device)
728
+ elif criterion in ["latent-diff", "none"]:
729
+ pass
730
+ elif criterion == "kl":
731
+ V = self.config.padded_vocab_size
732
+ log_probs = (1 / V * torch.ones(V, device=input_ids.device)).log()
733
+ elif criterion == "argmax-stability":
734
+ stable_for_n_steps = 0
735
+ current_argmax = torch.tensor(-1, dtype=torch.long, device=input_ids.device)
736
+ else:
737
+ raise ValueError("Invalid adaptive compute strategy.")
738
+
739
+ all_latents = []
740
+ for compute_step in range(1, model_inputs["num_steps"]):
741
+ prev_latents = current_latents.clone()
742
+ current_latents, block_idx, _ = self.iterate_one_step(
743
+ embedded_inputs, current_latents, block_idx=block_idx, **aux_inputs
744
+ )
745
+ all_latents.append(current_latents if latent_dampening else None)
746
+ if compute_step > 1 and step > 0: # do not exit in prefill:
747
+ if criterion == "entropy-diff":
748
+ prev_entropy = entropy.clone()
749
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
750
+ probs = F.softmax(outputs.logits[:, -1, :], dim=-1) # type: ignore
751
+ entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1).mean()
752
+ entropy_diff = (entropy - prev_entropy).abs()
753
+ if entropy_diff < 1e-3:
754
+ compute_steps.append([compute_step, entropy_diff.item()])
755
+ break
756
+ elif criterion == "latent-diff":
757
+ norm_diff = (prev_latents - current_latents).norm()
758
+ if norm_diff < 1:
759
+ compute_steps.append([compute_step, norm_diff.item()])
760
+ break
761
+ elif criterion == "kl":
762
+ prev_log_probs = log_probs.clone()
763
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
764
+ log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1) # type: ignore
765
+ kl = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1)
766
+ if kl < 2e-4:
767
+ compute_steps.append([compute_step, kl.item()])
768
+ break
769
+ elif criterion == "argmax-stability":
770
+ prev_argmax = current_argmax.clone()
771
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
772
+ current_argmax = outputs.logits[0, -1, :].argmax(dim=-1) # type: ignore
773
+ if current_argmax == prev_argmax:
774
+ stable_for_n_steps += 1
775
+ else:
776
+ stable_for_n_steps = 0
777
+ if stable_for_n_steps >= 10:
778
+ compute_steps.append([compute_step, stable_for_n_steps])
779
+ break
780
+ elif criterion == "none":
781
+ pass
782
+
783
+ else:
784
+ compute_steps.append([compute_step, float("NaN")])
785
+ if not latent_dampening:
786
+ outputs = self.predict_from_latents(current_latents, **aux_inputs)
787
+ else:
788
+ dampened_latents = torch.sum(torch.cat(all_latents, dim=0), dim=0, keepdim=True)
789
+ outputs = self.predict_from_latents(dampened_latents, **aux_inputs)
790
+
791
+ next_token_logits = outputs.logits[0, -1, :] # type: ignore
792
+ if continuous_compute: # Save last latent
793
+ current_last_latent = current_latents[:, -1:, :]
794
+
795
+ # Sample or select next token
796
+ if generation_config.do_sample:
797
+ if generation_config.temperature:
798
+ next_token_logits = next_token_logits / generation_config.temperature
799
+
800
+ probs = F.softmax(next_token_logits, dim=-1)
801
+ # Apply top_k
802
+ if generation_config.top_k:
803
+ top_k_probs, _ = torch.topk(probs, generation_config.top_k)
804
+ probs[probs < top_k_probs[-1]] = 0
805
+ # Apply top_p
806
+ if generation_config.top_p:
807
+ sorted_probs = torch.sort(probs, descending=True)[0]
808
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
809
+ probs[cumsum > generation_config.top_p] = 0
810
+ # Apply min_p
811
+ if generation_config.min_p:
812
+ probs[probs < generation_config.min_p * probs.max()] = 0
813
+
814
+ probs = probs / probs.sum()
815
+ next_token = torch.multinomial(probs, num_samples=1)
816
+ else:
817
+ next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
818
+
819
+ input_ids = torch.cat([input_ids, next_token[None, :]], dim=-1) # type: ignore
820
+
821
+ if streamer:
822
+ streamer.put(next_token.cpu())
823
+
824
+ # Update model kwargs
825
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
826
+
827
+ # Check if we hit a stop token
828
+ if stop_tokens is not None and next_token in stop_tokens:
829
+ break
830
+
831
+ if streamer:
832
+ streamer.end()
833
+
834
+ if generation_config.return_dict_in_generate:
835
+ return GenerateDecoderOnlyOutput(
836
+ sequences=input_ids,
837
+ scores=compute_steps, # type: ignore
838
+ logits=None,
839
+ attentions=None,
840
+ hidden_states=None,
841
+ past_key_values=model_kwargs.get("past_key_values"),
842
+ )
843
+ return input_ids
844
+
845
+ def _get_stops(self, generation_config, tokenizer):
846
+ stop_tokens = set()
847
+ if generation_config.eos_token_id is not None:
848
+ stop_tokens.add(generation_config.eos_token_id)
849
+ if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings:
850
+ for s in generation_config.stop_strings:
851
+ token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0]
852
+ stop_tokens.add(token_id)
853
+ return torch.tensor(list(stop_tokens))
854
+
855
+ def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad):
856
+ probs = torch.softmax(logits.float(), dim=-1)
857
+ prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1)
858
+ residual_diff = (x - latent_states).norm(dim=-1)
859
+ rel_residual = residual_diff / latent_states.norm(dim=-1)
860
+ stats = {
861
+ "entropy": prob_entropy,
862
+ "residual_diff": residual_diff,
863
+ "rel_residual": rel_residual,
864
+ "num_steps_no_grad": num_steps_no_grad,
865
+ "num_steps_with_grad": num_steps_with_grad,
866
+ }
867
+ return stats
868
+
869
+
870
+ #################################### Utils #######################################################################
871
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, condense_ratio: int = 1):
872
+ with torch.autocast("cuda", enabled=False):
873
+ inv_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
874
+ t = torch.arange(end, dtype=torch.float32, device=inv_freqs.device) / condense_ratio
875
+ freqs = torch.outer(t, inv_freqs).float()
876
+ return torch.stack([torch.cos(freqs)[None, :, None, :], torch.sin(freqs)[None, :, None, :]], dim=4)
877
+ # equivalent to
878
+ # freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
879
+ # cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
880
+
881
+
882
+ def apply_rotary_emb_complex_like(q: Tensor, k: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
883
+ with torch.autocast("cuda", enabled=False):
884
+ qk_r2 = torch.cat([q, k], dim=2).unflatten(dim=-1, sizes=(-1, 2)).float() # cast to float32 for smooth skin
885
+ rotated_qk_r2 = torch.stack(
886
+ [
887
+ qk_r2[..., 0] * freqs_cis[..., 0] - qk_r2[..., 1] * freqs_cis[..., 1],
888
+ qk_r2[..., 1] * freqs_cis[..., 0] + qk_r2[..., 0] * freqs_cis[..., 1],
889
+ ],
890
+ -1,
891
+ ).flatten(3)
892
+ rotated_qk = rotated_qk_r2
893
+ return torch.split(rotated_qk.type_as(q), q.shape[2], dim=2) # type: ignore
894
+
895
+
896
+ #################################### HF registration ############################################################
897
+
898
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
899
+
900
+ # New
901
+ RavenConfig.register_for_auto_class()
902
+
903
+ RavenForCausalLM.register_for_auto_class("AutoModel")
904
+ RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM")
905
+
906
+ # Old?
907
+ AutoConfig.register("huginn_raven", RavenConfig)
908
+ AutoModel.register(RavenConfig, RavenForCausalLM)
909
+ AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM)
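
`generate_minimal` and `generate_with_adaptive_compute` are custom entry points rather than overrides of `GenerationMixin.generate`, so they are called explicitly. A hedged usage sketch follows, assuming `model` and `tokenizer` are loaded as in the snippet after the README; the chosen `max_length` and `num_steps` are illustrative only.

```python
# Sketch: greedy decoding via the custom generate_minimal entry point.
import torch
from transformers import GenerationConfig

gen_cfg = GenerationConfig(max_length=64, do_sample=False)
input_ids = tokenizer("Claims require proof because", return_tensors="pt")["input_ids"]

with torch.no_grad():
    out = model.generate_minimal(
        input_ids,
        generation_config=gen_cfg,
        tokenizer=tokenizer,
        num_steps=16,   # recurrence depth per decoded token; higher trades compute for quality
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
```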