teragron commited on
Commit
233119d
1 Parent(s): 10f7f48

Upload 21 files

Browse files
Files changed (21) hide show
  1. assets/llama_cute.jpg +0 -0
  2. build_msvc.bat +1 -0
  3. configurator.py +47 -0
  4. doc/stories260K.md +58 -0
  5. doc/train_llama_tokenizer.md +99 -0
  6. export.py +567 -0
  7. model.py +343 -0
  8. requirements.txt +7 -0
  9. run.c +973 -0
  10. run.ipynb +130 -0
  11. runq.c +1092 -0
  12. sample.py +79 -0
  13. test.c +84 -0
  14. test_all.py +89 -0
  15. tinystories.py +281 -0
  16. tokenizer.bin +3 -0
  17. tokenizer.model +3 -0
  18. tokenizer.py +78 -0
  19. train.py +343 -0
  20. win.c +180 -0
  21. win.h +69 -0
assets/llama_cute.jpg ADDED
build_msvc.bat ADDED
@@ -0,0 +1 @@
 
 
1
+ cl.exe /fp:fast /Ox /openmp /I. run.c win.c
configurator.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Poor Man's Configurator. Probably a terrible idea. Example usage:
3
+ $ python train.py config/override_file.py --batch_size=32
4
+ this will first run config/override_file.py, then override batch_size to 32
5
+
6
+ The code in this file will be run as follows from e.g. train.py:
7
+ >>> exec(open('configurator.py').read())
8
+
9
+ So it's not a Python module, it's just shuttling this code away from train.py
10
+ The code in this script then overrides the globals()
11
+
12
+ I know people are not going to love this, I just really dislike configuration
13
+ complexity and having to prepend config. to every single variable. If someone
14
+ comes up with a better simple Python solution I am all ears.
15
+ """
16
+
17
+ import sys
18
+ from ast import literal_eval
19
+
20
+ for arg in sys.argv[1:]:
21
+ if '=' not in arg:
22
+ # assume it's the name of a config file
23
+ assert not arg.startswith('--')
24
+ config_file = arg
25
+ print(f"Overriding config with {config_file}:")
26
+ with open(config_file) as f:
27
+ print(f.read())
28
+ exec(open(config_file).read())
29
+ else:
30
+ # assume it's a --key=value argument
31
+ assert arg.startswith('--')
32
+ key, val = arg.split('=')
33
+ key = key[2:]
34
+ if key in globals():
35
+ try:
36
+ # attempt to eval it it (e.g. if bool, number, or etc)
37
+ attempt = literal_eval(val)
38
+ except (SyntaxError, ValueError):
39
+ # if that goes wrong, just use the string
40
+ attempt = val
41
+ # ensure the types match ok
42
+ assert type(attempt) == type(globals()[key])
43
+ # cross fingers
44
+ print(f"Overriding: {key} = {attempt}")
45
+ globals()[key] = attempt
46
+ else:
47
+ raise ValueError(f"Unknown config key: {key}")
doc/stories260K.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stories260K
2
+
3
+ [Stories260K huggginface link](https://huggingface.co/karpathy/tinyllamas)
4
+
5
+ The 260K model is a tiny model used for testing, and was trained as follows:
6
+
7
+ ```
8
+ python train.py \
9
+ --out_dir="outmini" \
10
+ --batch_size=128 \
11
+ --max_seq_len=512 \
12
+ --gradient_accumulation_steps=1 \
13
+ --vocab_source="custom" \
14
+ --vocab_size=512 \
15
+ --dim=64 \
16
+ --n_layers=5 \
17
+ --n_heads=8 \
18
+ --n_kv_heads=4 \
19
+ --multiple_of=4 \
20
+ --learning_rate=1e-3 \
21
+ --dropout=0.05 \
22
+ --weight_decay=0.01 \
23
+ --max_iters=100000 \
24
+ --beta2=0.99 \
25
+ --warmup_iters=1000 \
26
+ --eval_interval=2000 \
27
+ --eval_iters=100 \
28
+ --compile=True
29
+ ```
30
+
31
+ You'll notice that `n_kv_heads` is 4 while `n_heads` is 8, so two heads at a time share their key,value projections, i.e. this model is 2X multiquery. You'll also notice that we're using a custom tokenizer with 512 tokens. The model trained for ~10 minutes (?) on my A100 and achieves validation loss of 1.2968.
32
+
33
+ Sampling this model at temperature 0.0 (i.e. deterministic greedy argmax sampling) gives:
34
+
35
+ ```
36
+ $ ./run stories260K/stories260K.bin -z stories260K/tok512.bin -t 0.0
37
+ Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.
38
+ Lily's mom said, "Lily, let's go to the park." Lily was sad and didn't know what to do. She said, "I want to play with your ball, but I can't find it."
39
+ Lily was sad and didn't know what to do. She said, "I'm sorry, Lily. I didn't know what to do."
40
+ Lily didn't want to help her mom, so she said, "I'm sorry, mom. I didn't know what to do." Her mom said, "Don't worry, Lily. We can help you.
41
+ ```
42
+
43
+ You can reproduce the same in Python by running `sample.py`:
44
+
45
+ ```
46
+ $ python sample.py --checkpoint=stories260K/stories260K.pt --tokenizer=stories260K/tok512.model --temperature=0.0 --max_new_tokens=257
47
+ ```
48
+
49
+ I hardcoded max tokens to be 257 manually because the `sample.py` script doesn't currently terminate on the special BOS token like the run.c script does. Sampling at 1.0 with topp of 0.9 gives a bit more reasonable samples:
50
+
51
+ ```
52
+ $ ./run stories260K/stories260K.bin -z stories260K/tok512.bin -t 1.0 -p 0.9 -s 133742
53
+ Once upon a time, there was a little boy named Timmy. Timmy loved to play with his toys and eat sandwiches. One day, Timmy's mom told him it was time to rest for a while. Timmy's friend Billy came over and took him a down.
54
+ Timmy's mom saw that Timmy was sad, but Timmy said, "I didn't understand what is it! We need to find some leafs." Timmy thought about it and took a deep breath on a spoon. He hoped it was important to be kind and continued to find its image next time.
55
+ After they finished getting, Timmy's dad came up to his house and promised to help Timmy.
56
+ ```
57
+
58
+ Hey you can't expect too much from a 260K parameter model. I'm even mildly shocked we get this far :D
doc/train_llama_tokenizer.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # training llama tokenizer
2
+
3
+ How does Meta train their sentencepiece tokenizer? You can print the config as follows:
4
+
5
+ ```python
6
+ import sentencepiece.sentencepiece_model_pb2
7
+ mp = sentencepiece.sentencepiece_model_pb2.ModelProto()
8
+ mp.ParseFromString(open("tokenizer.model", "rb").read())
9
+ print(mp.trainer_spec)
10
+ print(mp.normalizer_spec)
11
+ ```
12
+
13
+ this gives:
14
+
15
+ ```
16
+ trainer_spec {
17
+ input: "/large_experiments/theorem/datasets/MERGED/all.test1.merged"
18
+ model_prefix: "spm_model_32k_200M_charcov099995_allowWSO__v2"
19
+ model_type: BPE
20
+ vocab_size: 32000
21
+ self_test_sample_size: 0
22
+ input_format: "text"
23
+ character_coverage: 0.9999499917030334
24
+ input_sentence_size: 200000000
25
+ seed_sentencepiece_size: 1000000
26
+ shrinking_factor: 0.75
27
+ num_threads: 80
28
+ num_sub_iterations: 2
29
+ max_sentence_length: 4192
30
+ shuffle_input_sentence: true
31
+ max_sentencepiece_length: 16
32
+ split_by_unicode_script: true
33
+ split_by_whitespace: true
34
+ split_by_number: true
35
+ treat_whitespace_as_suffix: false
36
+ split_digits: true
37
+ allow_whitespace_only_pieces: true
38
+ vocabulary_output_piece_score: true
39
+ hard_vocab_limit: true
40
+ use_all_vocab: false
41
+ byte_fallback: true
42
+ required_chars: ""
43
+ unk_id: 0
44
+ bos_id: 1
45
+ eos_id: 2
46
+ pad_id: -1
47
+ unk_surface: " \342\201\207 "
48
+ unk_piece: "<unk>"
49
+ bos_piece: "<s>"
50
+ eos_piece: "</s>"
51
+ pad_piece: "<pad>"
52
+ train_extremely_large_corpus: false
53
+ enable_differential_privacy: false
54
+ differential_privacy_noise_level: 0.0
55
+ differential_privacy_clipping_threshold: 0
56
+ }
57
+ normalizer_spec {
58
+ name: "identity"
59
+ precompiled_charsmap: ""
60
+ add_dummy_prefix: true
61
+ remove_extra_whitespaces: false
62
+ normalization_rule_tsv: ""
63
+ }
64
+ ```
65
+
66
+ We can use the sentencepiece spm_train to train the same models, but optionally smaller. Here are their [options docs](https://github.com/google/sentencepiece/blob/master/doc/options.md) we can refer to. It's not much but it helps.
67
+
68
+ We'll depart on one setting, I recommend changing `character_coverage` -> 1.0. We also want to make sure to note the following important settings that come up in the paper and are not necessarily the default sentencepiece settings:
69
+
70
+ ```
71
+ --split-digits = true
72
+ --allow_whitespace_only_pieces = true
73
+ --byte_fallback = true
74
+ --normalization_rule_name = identity
75
+ ```
76
+
77
+ With this in mind we can train a sentencepiece vocab in what I believe is probably the same to how Meta trained theirs as:
78
+
79
+ ```
80
+ spm_train --input="$input" \
81
+ --model_prefix="$model_prefix" \
82
+ --model_type=bpe \
83
+ --vocab_size="$vocab_size" \
84
+ --self_test_sample_size=0 \
85
+ --input_format="text" \
86
+ --character_coverage=1.0 \
87
+ --num_threads="$(nproc)" \
88
+ --split_digits=true \
89
+ --allow_whitespace_only_pieces=true \
90
+ --byte_fallback=true \
91
+ --unk_surface=" \342\201\207 " \
92
+ --normalization_rule_name=identity \
93
+ ```
94
+
95
+ Where $input is the input file, $model_prefix is the output path prefix, vocab_size is the desired vocab, and we're by default taking over the CPU resources of the machine.
96
+
97
+ Lastly note that sentencepiece is weird and expects "sentences" delimited by newlines as the input. You can't just put in a massive block of text. And they have a hyperparameter that constols the maximum size of a "sentence". Fwiw I really dislike this design choice around a weird concept of a "sentence". It should just be block of text with no assumptions. But here we are.
98
+
99
+ Look into the file `tinystories.py` where we train the vocab in the same way, but using Python bindings instead.
export.py ADDED
@@ -0,0 +1,567 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script has functions and utilties for model export.
3
+ Basically, we have a bunch of versions of the model, and we
4
+ want to export them to .bin files to be read from and inferenced in C.
5
+
6
+ Among the "input" versions of PyTorch files/models:
7
+ - Official Llama 2 weights released by Meta
8
+ - Huggingface weights available on the hub
9
+ - llama2.c (this repo) trained models
10
+
11
+ Among the "output" versions of .bin files:
12
+ - v0: Legacy files of the original llama2.c repo (will eventually be DEPRECATED)
13
+ - v1-vN: Improved .bin files with a proper header, cache alignment, etc.
14
+
15
+ This script aspires to provide all of these conversions.
16
+ """
17
+ import os
18
+ import gzip
19
+ import shutil
20
+ import struct
21
+ import argparse
22
+ import json
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch import nn
28
+
29
+ from model import ModelArgs, Transformer
30
+
31
+ # -----------------------------------------------------------------------------
32
+ # common utilities
33
+
34
+ def serialize_fp32(file, tensor):
35
+ """ writes one fp32 tensor to file that is open in wb mode """
36
+ d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
37
+ b = struct.pack(f'{len(d)}f', *d)
38
+ file.write(b)
39
+
40
+ def serialize_int8(file, tensor):
41
+ """ writes one int8 tensor to file that is open in wb mode """
42
+ d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
43
+ b = struct.pack(f'{len(d)}b', *d)
44
+ file.write(b)
45
+
46
+ def quantize_q80(w, group_size):
47
+ """
48
+ takes a tensor and returns the Q8_0 quantized version
49
+ i.e. symmetric quantization into int8, range [-127,127]
50
+ """
51
+ assert w.numel() % group_size == 0
52
+ ori_shape = w.shape
53
+ w = w.float() # convert to float32
54
+ w = w.reshape(-1, group_size)
55
+ # find the max in each group
56
+ wmax = torch.abs(w).max(dim=1).values
57
+ # calculate the scaling factor such that float = quant * scale
58
+ scale = wmax / 127.0
59
+ # scale into range [-127, 127]
60
+ quant = w / scale[:,None]
61
+ # round to nearest integer
62
+ int8val = torch.round(quant).to(torch.int8)
63
+ # dequantize by rescaling
64
+ fp32val = (int8val.float() * scale[:,None]).view(-1)
65
+ fp32valr = fp32val.reshape(-1, group_size)
66
+ # calculate the max error in each group
67
+ err = torch.abs(fp32valr - w).max(dim=1).values
68
+ # find the max error across all groups
69
+ maxerr = err.max().item()
70
+ return int8val, scale, maxerr
71
+
72
+ # -----------------------------------------------------------------------------
73
+ # legacy
74
+
75
+ def legacy_export(model, filepath):
76
+ """ Original export of llama2.c bin files, i.e. version v0 """
77
+ out_file = open(filepath, 'wb')
78
+
79
+ # first write out the header
80
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
81
+ p = model.params
82
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
83
+ # legacy format uses negative/positive vocab size as a shared classifier flag
84
+ if not shared_classifier:
85
+ p.vocab_size = -p.vocab_size
86
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
87
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
88
+ n_kv_heads, p.vocab_size, p.max_seq_len)
89
+ out_file.write(header)
90
+
91
+ # next write out the embedding weights
92
+ serialize_fp32(out_file, model.tok_embeddings.weight)
93
+
94
+ # now all the layers
95
+ # attention weights
96
+ for layer in model.layers:
97
+ serialize_fp32(out_file, layer.attention_norm.weight)
98
+ for layer in model.layers:
99
+ serialize_fp32(out_file, layer.attention.wq.weight)
100
+ for layer in model.layers:
101
+ serialize_fp32(out_file, layer.attention.wk.weight)
102
+ for layer in model.layers:
103
+ serialize_fp32(out_file, layer.attention.wv.weight)
104
+ for layer in model.layers:
105
+ serialize_fp32(out_file, layer.attention.wo.weight)
106
+ # ffn weights
107
+ for layer in model.layers:
108
+ serialize_fp32(out_file, layer.ffn_norm.weight)
109
+ for layer in model.layers:
110
+ serialize_fp32(out_file, layer.feed_forward.w1.weight)
111
+ for layer in model.layers:
112
+ serialize_fp32(out_file, layer.feed_forward.w2.weight)
113
+ for layer in model.layers:
114
+ serialize_fp32(out_file, layer.feed_forward.w3.weight)
115
+ # final rmsnorm
116
+ serialize_fp32(out_file, model.norm.weight)
117
+ # freqs_cis
118
+ serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len])
119
+ serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len])
120
+
121
+ # final classifier weights
122
+ if not shared_classifier:
123
+ serialize_fp32(out_file, model.output.weight)
124
+
125
+ # write to binary file
126
+ out_file.close()
127
+ print(f"wrote {filepath}")
128
+
129
+ # -----------------------------------------------------------------------------
130
+ # new version
131
+
132
+ def version1_export(model, filepath):
133
+ """
134
+ Export the model weights in full float32 .bin file to be read from C.
135
+ This is same as legacy_export, but with a proper header.
136
+ """
137
+ version = 1
138
+
139
+ out_file = open(filepath, 'wb')
140
+ # first write out the header. the header will be 256 bytes
141
+ # 1) write magic, which will be uint32 of "ak42" in ASCII
142
+ out_file.write(struct.pack('I', 0x616b3432))
143
+ # 2) write version, which will be int
144
+ out_file.write(struct.pack('i', version))
145
+ # 3) write the params, which will be 7 ints
146
+ p = model.params
147
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
148
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
149
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
150
+ n_kv_heads, p.vocab_size, p.max_seq_len)
151
+ out_file.write(header)
152
+ # 4) write some other flags
153
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
154
+ out_file.write(struct.pack('B', int(shared_classifier)))
155
+ pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
156
+ assert pad >= 0
157
+ out_file.write(b'\0' * pad)
158
+
159
+ # now let's write out all the params
160
+ weights = [
161
+ *[layer.attention_norm.weight for layer in model.layers],
162
+ *[layer.ffn_norm.weight for layer in model.layers],
163
+ model.norm.weight,
164
+ model.tok_embeddings.weight,
165
+ *[layer.attention.wq.weight for layer in model.layers],
166
+ *[layer.attention.wk.weight for layer in model.layers],
167
+ *[layer.attention.wv.weight for layer in model.layers],
168
+ *[layer.attention.wo.weight for layer in model.layers],
169
+ *[layer.feed_forward.w1.weight for layer in model.layers],
170
+ *[layer.feed_forward.w2.weight for layer in model.layers],
171
+ *[layer.feed_forward.w3.weight for layer in model.layers],
172
+ ]
173
+ if not shared_classifier:
174
+ weights.append(model.output.weight)
175
+ for w in weights:
176
+ serialize_fp32(out_file, w)
177
+
178
+ # write to binary file
179
+ out_file.close()
180
+ print(f"wrote {filepath}")
181
+
182
+ def version2_export(model, filepath, group_size=64):
183
+ """
184
+ Export the model weights in Q8_0 into .bin file to be read from C.
185
+ That is:
186
+ - quantize all weights to symmetric int8, in range [-127, 127]
187
+ - all other tensors (the rmsnorm params) are kept and exported in fp32
188
+ - quantization is done in groups of group_size to reduce the effects of any outliers
189
+ """
190
+ version = 2
191
+
192
+ # let's first do some validation for this export type
193
+ while model.params.dim % group_size != 0:
194
+ group_size //= 2
195
+ print(f"BACKOFF: reducing group size to {group_size} to fit hidden_dim")
196
+ weights = [
197
+ model.tok_embeddings.weight,
198
+ *[layer.attention.wq.weight for layer in model.layers],
199
+ *[layer.attention.wk.weight for layer in model.layers],
200
+ *[layer.attention.wv.weight for layer in model.layers],
201
+ *[layer.attention.wo.weight for layer in model.layers],
202
+ *[layer.feed_forward.w1.weight for layer in model.layers],
203
+ *[layer.feed_forward.w2.weight for layer in model.layers],
204
+ *[layer.feed_forward.w3.weight for layer in model.layers],
205
+ ]
206
+ shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight)
207
+ if not shared_classifier:
208
+ weights.append(model.output.weight)
209
+ for w in weights:
210
+ assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}"
211
+
212
+ # write
213
+ out_file = open(filepath, 'wb')
214
+ # first write out the header. the header will be 256 bytes
215
+ # 1) write magic, which will be uint32 of "ak42" in ASCII
216
+ out_file.write(struct.pack('I', 0x616b3432))
217
+ # 2) write version, which will be int
218
+ out_file.write(struct.pack('i', version))
219
+ # 3) write the params, which will be 7 ints
220
+ p = model.params
221
+ hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0]
222
+ n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
223
+ header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
224
+ n_kv_heads, p.vocab_size, p.max_seq_len)
225
+ out_file.write(header)
226
+ # 4) write some other flags
227
+ out_file.write(struct.pack('B', int(shared_classifier)))
228
+ out_file.write(struct.pack('i', group_size)) # group size used for quantization
229
+ pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos
230
+ assert pad >= 0
231
+ out_file.write(b'\0' * pad)
232
+ # now that the header is done, let's write out the model
233
+
234
+ # first let's write out all the params that we are keeping in fp32: the norms
235
+ for layer in model.layers: # attention norms
236
+ serialize_fp32(out_file, layer.attention_norm.weight)
237
+ for layer in model.layers: # MLP norms
238
+ serialize_fp32(out_file, layer.ffn_norm.weight)
239
+ serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm
240
+
241
+ # now let's write out all the params that we are quantizing to Q8_0
242
+ # note we skip classifier weights, which are shared with the embedding
243
+ ew = []
244
+ for i, w in enumerate(weights):
245
+ # quantize this weight
246
+ q, s, err = quantize_q80(w, group_size)
247
+ # save the int8 weights to file
248
+ serialize_int8(out_file, q) # save the tensor in int8
249
+ serialize_fp32(out_file, s) # save scale factors
250
+ # logging
251
+ ew.append((err, w.shape))
252
+ print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}")
253
+
254
+ # print the highest error across all weights, should be very small, e.g. O(~0.001)
255
+ ew.sort(reverse=True)
256
+ print(f"max quantization group error across all weights: {ew[0][0]}")
257
+
258
+ # write to binary file
259
+ out_file.close()
260
+ print(f"wrote {filepath}")
261
+
262
+ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32):
263
+ """ Generate the pytorch_model.bin state_dict and config.json for HuggingFace """
264
+
265
+ try:
266
+ from transformers.models.llama.configuration_llama import LlamaConfig
267
+ except ImportError:
268
+ print("Error: transformers package is required to load huggingface models")
269
+ print("Please run `pip install transformers` to install it")
270
+ return None
271
+
272
+ # Generate LlamaModel state_dict
273
+ hf_state_dict = {}
274
+
275
+ # Sometimes we have repeated key values for the heads
276
+ dim = llama_model.params.dim
277
+ num_key_value_heads = llama_model.params.n_kv_heads
278
+ n_rep = llama_model.params.n_heads // num_key_value_heads
279
+ key_value_dim = dim // n_rep
280
+
281
+ # HuggingFace needs the weights permuted.
282
+ # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122
283
+ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim):
284
+ return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
285
+
286
+ # Transfer weights from llama model to the HF state dictionary format
287
+ hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype)
288
+ hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype)
289
+
290
+ # Add each layer's weights to the HF state dictionary
291
+ for i, layer in enumerate(llama_model.layers):
292
+ layer_id = layer.layer_id
293
+ hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype)
294
+ hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype)
295
+ hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype)
296
+ hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype)
297
+ hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype)
298
+ hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype)
299
+ hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype)
300
+ hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype)
301
+ hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype)
302
+
303
+ # llama2.c usually uses tied weights -> reference the embed_tokens.weights instead
304
+ hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight']
305
+
306
+ # We check that the embeddings are tied, else use manual output weights
307
+ _embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight)
308
+ if not _embeddings_are_tied:
309
+ hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype)
310
+
311
+
312
+ # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama)
313
+
314
+ # Extract necessary attributes from llama.c model
315
+ vocab_size = llama_model.params.vocab_size
316
+ hidden_size = llama_model.params.dim
317
+ intermediate_size = llama_model.layers[0].feed_forward.w1.weight.shape[0]
318
+ num_hidden_layers = llama_model.params.n_layers
319
+ num_attention_heads = llama_model.params.n_heads
320
+ num_key_value_heads = llama_model.params.n_kv_heads
321
+ max_position_embeddings = llama_model.params.max_seq_len
322
+ rms_norm_eps = llama_model.params.norm_eps
323
+
324
+ # TODO check values for:
325
+ # pretraining_tp, initializer_range, use_cache,
326
+ # rope_theta, and rope_scaling.
327
+
328
+ config = LlamaConfig(
329
+ vocab_size=vocab_size,
330
+ hidden_size=hidden_size,
331
+ intermediate_size=intermediate_size,
332
+ num_hidden_layers=num_hidden_layers,
333
+ num_attention_heads=num_attention_heads,
334
+ num_key_value_heads=num_key_value_heads,
335
+ max_position_embeddings=max_position_embeddings,
336
+ rms_norm_eps=rms_norm_eps,
337
+ tie_word_embeddings=_embeddings_are_tied,
338
+ # Manual
339
+ architectures=["LlamaForCausalLM"],
340
+ hidden_act="silu",
341
+ )
342
+
343
+
344
+ # Save files in directory filepath
345
+ # First make the directory if it doesn't exist
346
+ os.makedirs(filepath, exist_ok=True)
347
+
348
+ # Save the state dictionary in .bin format, and config as .json
349
+ torch.save(hf_state_dict, os.path.join(filepath, "pytorch_model.bin"))
350
+ config.save_pretrained(filepath)
351
+
352
+
353
+ # -----------------------------------------------------------------------------
354
+ # Load / import functions
355
+
356
+ def load_checkpoint(checkpoint):
357
+
358
+ # load the provided model checkpoint
359
+ checkpoint_dict = torch.load(checkpoint, map_location='cpu')
360
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
361
+ model = Transformer(gptconf)
362
+ state_dict = checkpoint_dict['model']
363
+ unwanted_prefix = '_orig_mod.'
364
+ for k,v in list(state_dict.items()):
365
+ if k.startswith(unwanted_prefix):
366
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
367
+ model.load_state_dict(state_dict, strict=False)
368
+ model.eval()
369
+ return model
370
+
371
+ def load_meta_model(model_path):
372
+ params_path = os.path.join(model_path, 'params.json')
373
+ with open(params_path) as f:
374
+ params = json.load(f)
375
+ print(params)
376
+
377
+ model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth')))
378
+ models = [torch.load(p, map_location='cpu') for p in model_paths]
379
+
380
+ def concat_weights(models):
381
+ state_dict = {}
382
+ for name in list(models[0]):
383
+ tensors = [model[name] for model in models]
384
+ if len(tensors) == 1 or len(tensors[0].shape) == 1:
385
+ state_dict[name] = tensors[0]
386
+ continue
387
+ is_axis_1 = (
388
+ name.startswith('tok_embeddings.')
389
+ or name.endswith('.attention.wo.weight')
390
+ or name.endswith('.feed_forward.w2.weight')
391
+ )
392
+ axis = 1 if is_axis_1 else 0
393
+ state_dict[name] = torch.cat(tensors, dim=axis)
394
+ for model in models:
395
+ del model[name]
396
+ return state_dict
397
+
398
+ state_dict = concat_weights(models)
399
+ del models
400
+
401
+ # set ModelArgs
402
+ config = ModelArgs()
403
+ config.dim = params["dim"]
404
+ config.n_layers = params["n_layers"]
405
+ config.n_heads = params["n_heads"]
406
+ config.n_kv_heads = params.get('n_kv_heads') or params['n_heads']
407
+ config.multiple_of = params["multiple_of"]
408
+ config.norm_eps = params["norm_eps"]
409
+
410
+ config.vocab_size = state_dict['tok_embeddings.weight'].shape[0]
411
+ config.max_seq_len = 2048
412
+
413
+
414
+ # create a new Transformer object and set weights
415
+ model = Transformer(config)
416
+
417
+ model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight'])
418
+ model.norm.weight = nn.Parameter(state_dict['norm.weight'])
419
+
420
+ for layer in model.layers:
421
+ i = layer.layer_id
422
+ layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight'])
423
+ layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight'])
424
+ layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight'])
425
+ layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight'])
426
+ layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight'])
427
+ layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight'])
428
+ layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight'])
429
+ layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight'])
430
+ layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight'])
431
+
432
+ # final classifier
433
+ model.output.weight = nn.Parameter(state_dict['output.weight'])
434
+ model.eval()
435
+ return model
436
+
437
+ def load_hf_model(model_path):
438
+
439
+ try:
440
+ from transformers import AutoModelForCausalLM
441
+ except ImportError:
442
+ print("Error: transformers package is required to load huggingface models")
443
+ print("Please run `pip install transformers` to install it")
444
+ return None
445
+
446
+ # load HF model
447
+ hf_model = AutoModelForCausalLM.from_pretrained(model_path)
448
+ hf_dict = hf_model.state_dict()
449
+
450
+ # convert LlamaConfig to ModelArgs
451
+ config = ModelArgs()
452
+ config.dim = hf_model.config.hidden_size
453
+ config.n_layers = hf_model.config.num_hidden_layers
454
+ config.n_heads = hf_model.config.num_attention_heads
455
+ config.n_kv_heads = hf_model.config.num_attention_heads
456
+ config.vocab_size = hf_model.config.vocab_size
457
+ config.hidden_dim = hf_model.config.intermediate_size
458
+ config.norm_eps = hf_model.config.rms_norm_eps
459
+ config.max_seq_len = hf_model.config.max_position_embeddings
460
+
461
+ # create a new Transformer object and set weights
462
+ model = Transformer(config)
463
+
464
+ model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
465
+ model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])
466
+
467
+ # huggingface permutes WQ and WK, this function reverses it
468
+ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
469
+ return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
470
+
471
+ for layer in model.layers:
472
+ i = layer.layer_id
473
+ layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
474
+ layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
475
+ layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
476
+ layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
477
+ layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
478
+ layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
479
+ layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
480
+ layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
481
+ layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])
482
+
483
+ # final classifier
484
+ model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
485
+ model.eval()
486
+ return model
487
+
488
+
489
+ # -----------------------------------------------------------------------------
490
+ # API entrypoint
491
+
492
+ def model_export(model, filepath, version, dtype=torch.float32):
493
+ """
494
+ Versions docs:
495
+ v-1:huggingface export, i.e. intended for use outside of this repo, in HF
496
+ v0: legacy llama2.c float format, DEPRECATED
497
+ v1: float32 export
498
+ v2: int8 quantized Q8_0 export, similar to llama.cpp, in groups
499
+ # TODO: add dtype export support for other versions (?)
500
+ """
501
+ if version == 0:
502
+ legacy_export(model, filepath)
503
+ elif version == 1:
504
+ version1_export(model, filepath)
505
+ elif version == 2:
506
+ version2_export(model, filepath)
507
+ elif version == -1:
508
+ hf_export(model, filepath, dtype)
509
+ else:
510
+ raise ValueError(f"unknown version {version}")
511
+
512
+ def torchscript_export(model, filepath, zero_params=False, gzip_output=False):
513
+ """
514
+ (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now)
515
+ Saves the model as a TorchScript.
516
+ The resulting file can be loaded in C++ code and then used for training or
517
+ inference with:
518
+ #include <torch/script.h>
519
+ torch::jit::Module module = torch::jit::load("model.pt")
520
+ Note that the serialized model includes the initial parameters and with the default
521
+ ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
522
+ the model parameters separately you can zero out the parameters before saving it and
523
+ it will gzip down to 780K.
524
+ """
525
+
526
+ # If requested zero params before saving the model. This is useful in
527
+ # conjunction with gzip_output.
528
+ if zero_params:
529
+ for p in model.parameters():
530
+ p.detach().zero_()
531
+
532
+ torch.jit.save(torch.jit.script(model), filepath)
533
+
534
+ if gzip_output:
535
+ with open(filepath, "rb") as f_in:
536
+ with gzip.open(f"{filepath}.gz", "wb") as f_out:
537
+ shutil.copyfileobj(f_in, f_out)
538
+ os.unlink(filepath)
539
+
540
+ # -----------------------------------------------------------------------------
541
+ # CLI entrypoint
542
+
543
+ if __name__ == "__main__":
544
+
545
+ parser = argparse.ArgumentParser()
546
+ parser.add_argument("filepath", type=str, help="the output filepath")
547
+ parser.add_argument("--version", default=0, type=int, help="the version to export with")
548
+ parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32")
549
+ group = parser.add_mutually_exclusive_group(required=True)
550
+ group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file")
551
+ group.add_argument("--meta-llama", type=str, help="meta llama model path")
552
+ group.add_argument("--hf", type=str, help="huggingface model path")
553
+ args = parser.parse_args()
554
+ dtype = {"fp16": torch.float16, "fp32": torch.float32}[args.dtype]
555
+
556
+ if args.checkpoint:
557
+ model = load_checkpoint(args.checkpoint)
558
+ elif args.meta_llama:
559
+ model = load_meta_model(args.meta_llama)
560
+ elif args.hf:
561
+ model = load_hf_model(args.hf)
562
+
563
+ if model is None:
564
+ parser.error("Can't load input model!")
565
+
566
+ # export
567
+ model_export(model, args.filepath, args.version, args.dtype)
model.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import struct
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ @dataclass
13
+ class ModelArgs:
14
+ # default hyperparameters for the Llama 7B model
15
+ dim: int = 4096
16
+ n_layers: int = 32
17
+ n_heads: int = 32
18
+ n_kv_heads: Optional[int] = None
19
+ vocab_size: int = 32000
20
+ hidden_dim: Optional[int] = None
21
+ multiple_of: int = 256 # MLP hidden layer size will be multiple of
22
+ norm_eps: float = 1e-5
23
+ max_seq_len: int = 2048
24
+ dropout: float = 0.0
25
+
26
+
27
+ class RMSNorm(torch.nn.Module):
28
+ def __init__(self, dim: int, eps: float):
29
+ super().__init__()
30
+ self.eps = eps
31
+ self.weight = nn.Parameter(torch.ones(dim))
32
+
33
+ def _norm(self, x):
34
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
35
+
36
+ def forward(self, x):
37
+ output = self._norm(x.float()).type_as(x)
38
+ return output * self.weight
39
+
40
+
41
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
42
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
43
+ t = torch.arange(end, device=freqs.device) # type: ignore
44
+ freqs = torch.outer(t, freqs).float() # type: ignore
45
+ freqs_cos = torch.cos(freqs) # real part
46
+ freqs_sin = torch.sin(freqs) # imaginary part
47
+ return freqs_cos, freqs_sin
48
+
49
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
50
+ ndim = x.ndim
51
+ assert 0 <= 1 < ndim
52
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1])
53
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
54
+ return freqs_cis.view(shape)
55
+
56
+ def apply_rotary_emb(
57
+ xq: torch.Tensor,
58
+ xk: torch.Tensor,
59
+ freqs_cos: torch.Tensor,
60
+ freqs_sin: torch.Tensor
61
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
62
+
63
+ # reshape xq and xk to match the complex representation
64
+ xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
65
+ xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
66
+
67
+ # reshape freqs_cos and freqs_sin for broadcasting
68
+ freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
69
+ freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
70
+
71
+ # apply rotation using real numbers
72
+ xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
73
+ xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
74
+ xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
75
+ xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
76
+
77
+ # flatten last two dimensions
78
+ xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
79
+ xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
80
+
81
+ return xq_out.type_as(xq), xk_out.type_as(xk)
82
+
83
+ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
84
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
85
+ bs, slen, n_kv_heads, head_dim = x.shape
86
+ if n_rep == 1:
87
+ return x
88
+ return (
89
+ x[:, :, :, None, :]
90
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
91
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
92
+ )
93
+
94
+ class Attention(nn.Module):
95
+ def __init__(self, args: ModelArgs):
96
+ super().__init__()
97
+ self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
98
+ assert args.n_heads % self.n_kv_heads == 0
99
+ model_parallel_size = 1
100
+ self.n_local_heads = args.n_heads // model_parallel_size
101
+ self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
102
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
103
+ self.head_dim = args.dim // args.n_heads
104
+ self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
105
+ self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
106
+ self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
107
+ self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
108
+ self.attn_dropout = nn.Dropout(args.dropout)
109
+ self.resid_dropout = nn.Dropout(args.dropout)
110
+ self.dropout = args.dropout
111
+
112
+ # use flash attention or a manual implementation?
113
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
114
+ if not self.flash:
115
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
116
+ mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
117
+ mask = torch.triu(mask, diagonal=1)
118
+ self.register_buffer("mask", mask)
119
+
120
+ def forward(
121
+ self,
122
+ x: torch.Tensor,
123
+ freqs_cos: torch.Tensor,
124
+ freqs_sin: torch.Tensor,
125
+ ):
126
+ bsz, seqlen, _ = x.shape
127
+
128
+ # QKV
129
+ xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
130
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
131
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
132
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
133
+
134
+ # RoPE relative positional embeddings
135
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
136
+
137
+ # grouped multiquery attention: expand out keys and values
138
+ xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
139
+ xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
140
+
141
+ # make heads into a batch dimension
142
+ xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
143
+ xk = xk.transpose(1, 2)
144
+ xv = xv.transpose(1, 2)
145
+
146
+ # flash implementation
147
+ if self.flash:
148
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
149
+ else:
150
+ # manual implementation
151
+ scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
152
+ assert hasattr(self, 'mask')
153
+ scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
154
+ scores = F.softmax(scores.float(), dim=-1).type_as(xq)
155
+ scores = self.attn_dropout(scores)
156
+ output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
157
+
158
+ # restore time as batch dimension and concat heads
159
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
160
+
161
+ # final projection into the residual stream
162
+ output = self.wo(output)
163
+ output = self.resid_dropout(output)
164
+ return output
165
+
166
+
167
+ class FeedForward(nn.Module):
168
+ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
169
+ super().__init__()
170
+ if hidden_dim is None:
171
+ hidden_dim = 4 * dim
172
+ hidden_dim = int(2 * hidden_dim / 3)
173
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
174
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
175
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
176
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
177
+ self.dropout = nn.Dropout(dropout)
178
+
179
+ def forward(self, x):
180
+ return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
181
+
182
+
183
+ class TransformerBlock(nn.Module):
184
+ def __init__(self, layer_id: int, args: ModelArgs):
185
+ super().__init__()
186
+ self.n_heads = args.n_heads
187
+ self.dim = args.dim
188
+ self.head_dim = args.dim // args.n_heads
189
+ self.attention = Attention(args)
190
+ self.feed_forward = FeedForward(
191
+ dim=args.dim,
192
+ hidden_dim=args.hidden_dim,
193
+ multiple_of=args.multiple_of,
194
+ dropout=args.dropout,
195
+ )
196
+ self.layer_id = layer_id
197
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
198
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
199
+
200
+ def forward(self, x, freqs_cos, freqs_sin):
201
+ h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
202
+ out = h + self.feed_forward.forward(self.ffn_norm(h))
203
+ return out
204
+
205
+
206
+ class Transformer(nn.Module):
207
+ last_loss: Optional[torch.Tensor]
208
+
209
+ def __init__(self, params: ModelArgs):
210
+ super().__init__()
211
+ self.params = params
212
+ self.vocab_size = params.vocab_size
213
+ self.n_layers = params.n_layers
214
+
215
+ self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
216
+ self.dropout = nn.Dropout(params.dropout)
217
+ self.layers = torch.nn.ModuleList()
218
+ for layer_id in range(params.n_layers):
219
+ self.layers.append(TransformerBlock(layer_id, params))
220
+ self.norm = RMSNorm(params.dim, eps=params.norm_eps)
221
+ self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
222
+
223
+ # share the unembedding parameters with the embedding parameters
224
+ self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
225
+
226
+ # some useful precompute for the RoPE relative positional embeddings
227
+ freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
228
+ self.register_buffer("freqs_cos", freqs_cos, persistent=False)
229
+ self.register_buffer("freqs_sin", freqs_sin, persistent=False)
230
+
231
+ # init all weights
232
+ self.apply(self._init_weights)
233
+ # apply special scaled init to the residual projections, per GPT-2 paper
234
+ for pn, p in self.named_parameters():
235
+ if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
236
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
237
+
238
+ # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
239
+ self.last_loss = None
240
+
241
+ def _init_weights(self, module):
242
+ if isinstance(module, nn.Linear):
243
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
244
+ if module.bias is not None:
245
+ torch.nn.init.zeros_(module.bias)
246
+ elif isinstance(module, nn.Embedding):
247
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
248
+
249
+ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
250
+ _bsz, seqlen = tokens.shape
251
+ h = self.tok_embeddings(tokens)
252
+ h = self.dropout(h)
253
+ freqs_cos = self.freqs_cos[:seqlen]
254
+ freqs_sin = self.freqs_sin[:seqlen]
255
+
256
+ for layer in self.layers:
257
+ h = layer(h, freqs_cos, freqs_sin)
258
+ h = self.norm(h)
259
+
260
+ if targets is not None:
261
+ # if we are given some desired targets also calculate the loss
262
+ logits = self.output(h)
263
+ self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
264
+ else:
265
+ # inference-time mini-optimization: only forward the output on the very last position
266
+ logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
267
+ self.last_loss = None
268
+
269
+ return logits
270
+
271
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
272
+ # start with all of the candidate parameters
273
+ param_dict = {pn: p for pn, p in self.named_parameters()}
274
+ # filter out those that do not require grad
275
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
276
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
277
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
278
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
279
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
280
+ optim_groups = [
281
+ {'params': decay_params, 'weight_decay': weight_decay},
282
+ {'params': nodecay_params, 'weight_decay': 0.0}
283
+ ]
284
+ num_decay_params = sum(p.numel() for p in decay_params)
285
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
286
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
287
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
288
+ # Create AdamW optimizer and use the fused version if it is available
289
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
290
+ use_fused = fused_available and device_type == 'cuda'
291
+ extra_args = dict(fused=True) if use_fused else dict()
292
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
293
+ print(f"using fused AdamW: {use_fused}")
294
+
295
+ return optimizer
296
+
297
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
298
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
299
+ # first estimate the number of flops we do per iteration.
300
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
301
+ N = sum(p.numel() for p in self.parameters())
302
+ cfg = self.params
303
+ L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
304
+ flops_per_token = 6*N + 12*L*H*Q*T
305
+ flops_per_fwdbwd = flops_per_token * T
306
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
307
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
308
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
309
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
310
+ mfu = flops_achieved / flops_promised
311
+ return mfu
312
+
313
+ @torch.inference_mode()
314
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
315
+ """
316
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
317
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
318
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
319
+ Also note this is a super inefficient version of sampling with no key/value cache.
320
+ """
321
+ for _ in range(max_new_tokens):
322
+ # if the sequence context is growing too long we must crop it at block_size
323
+ idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
324
+ # forward the model to get the logits for the index in the sequence
325
+ logits = self(idx_cond)
326
+ logits = logits[:, -1, :] # crop to just the final time step
327
+ if temperature == 0.0:
328
+ # "sample" the single most likely index
329
+ _, idx_next = torch.topk(logits, k=1, dim=-1)
330
+ else:
331
+ # pluck the logits at the final step and scale by desired temperature
332
+ logits = logits / temperature
333
+ # optionally crop the logits to only the top k options
334
+ if top_k is not None:
335
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
336
+ logits[logits < v[:, [-1]]] = -float('Inf')
337
+ # apply softmax to convert logits to (normalized) probabilities
338
+ probs = F.softmax(logits, dim=-1)
339
+ idx_next = torch.multinomial(probs, num_samples=1)
340
+ # append sampled index to the running sequence and continue
341
+ idx = torch.cat((idx, idx_next), dim=1)
342
+
343
+ return idx
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy==1.23.5
2
+ pytest==7.4.0
3
+ Requests==2.31.0
4
+ sentencepiece==0.1.99
5
+ torch==2.0.1
6
+ tqdm==4.64.1
7
+ wandb==0.15.5
run.c ADDED
@@ -0,0 +1,973 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Inference for Llama-2 Transformer model in pure C */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <time.h>
7
+ #include <math.h>
8
+ #include <string.h>
9
+ #include <fcntl.h>
10
+ #if defined _WIN32
11
+ #include "win.h"
12
+ #else
13
+ #include <unistd.h>
14
+ #include <sys/mman.h>
15
+ #endif
16
+ // ----------------------------------------------------------------------------
17
+ // Transformer model
18
+
19
+ typedef struct {
20
+ int dim; // transformer dimension
21
+ int hidden_dim; // for ffn layers
22
+ int n_layers; // number of layers
23
+ int n_heads; // number of query heads
24
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
25
+ int vocab_size; // vocabulary size, usually 256 (byte-level)
26
+ int seq_len; // max sequence length
27
+ } Config;
28
+
29
+ typedef struct {
30
+ // token embedding table
31
+ float* token_embedding_table; // (vocab_size, dim)
32
+ // weights for rmsnorms
33
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
34
+ float* rms_ffn_weight; // (layer, dim)
35
+ // weights for matmuls. note dim == n_heads * head_size
36
+ float* wq; // (layer, dim, n_heads * head_size)
37
+ float* wk; // (layer, dim, n_kv_heads * head_size)
38
+ float* wv; // (layer, dim, n_kv_heads * head_size)
39
+ float* wo; // (layer, n_heads * head_size, dim)
40
+ // weights for ffn
41
+ float* w1; // (layer, hidden_dim, dim)
42
+ float* w2; // (layer, dim, hidden_dim)
43
+ float* w3; // (layer, hidden_dim, dim)
44
+ // final rmsnorm
45
+ float* rms_final_weight; // (dim,)
46
+ // (optional) classifier weights for the logits, on the last layer
47
+ float* wcls;
48
+ } TransformerWeights;
49
+
50
+ typedef struct {
51
+ // current wave of activations
52
+ float *x; // activation at current time stamp (dim,)
53
+ float *xb; // same, but inside a residual branch (dim,)
54
+ float *xb2; // an additional buffer just for convenience (dim,)
55
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
56
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
57
+ float *q; // query (dim,)
58
+ float *k; // key (dim,)
59
+ float *v; // value (dim,)
60
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
61
+ float *logits; // output logits
62
+ // kv cache
63
+ float* key_cache; // (layer, seq_len, dim)
64
+ float* value_cache; // (layer, seq_len, dim)
65
+ } RunState;
66
+
67
+ typedef struct {
68
+ Config config; // the hyperparameters of the architecture (the blueprint)
69
+ TransformerWeights weights; // the weights of the model
70
+ RunState state; // buffers for the "wave" of activations in the forward pass
71
+ // some more state needed to properly clean up the memory mapping (sigh)
72
+ int fd; // file descriptor for memory mapping
73
+ float* data; // memory mapped data pointer
74
+ ssize_t file_size; // size of the checkpoint file in bytes
75
+ } Transformer;
76
+
77
+ void malloc_run_state(RunState* s, Config* p) {
78
+ // we calloc instead of malloc to keep valgrind happy
79
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
80
+ s->x = calloc(p->dim, sizeof(float));
81
+ s->xb = calloc(p->dim, sizeof(float));
82
+ s->xb2 = calloc(p->dim, sizeof(float));
83
+ s->hb = calloc(p->hidden_dim, sizeof(float));
84
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
85
+ s->q = calloc(p->dim, sizeof(float));
86
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
87
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
88
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
89
+ s->logits = calloc(p->vocab_size, sizeof(float));
90
+ // ensure all mallocs went fine
91
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
92
+ || !s->key_cache || !s->value_cache || !s->att || !s->logits) {
93
+ fprintf(stderr, "malloc failed!\n");
94
+ exit(EXIT_FAILURE);
95
+ }
96
+ }
97
+
98
+ void free_run_state(RunState* s) {
99
+ free(s->x);
100
+ free(s->xb);
101
+ free(s->xb2);
102
+ free(s->hb);
103
+ free(s->hb2);
104
+ free(s->q);
105
+ free(s->att);
106
+ free(s->logits);
107
+ free(s->key_cache);
108
+ free(s->value_cache);
109
+ }
110
+
111
+ void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
112
+ int head_size = p->dim / p->n_heads;
113
+ // make sure the multiplications below are done in 64bit to fit the parameter counts of 13B+ models
114
+ unsigned long long n_layers = p->n_layers;
115
+ w->token_embedding_table = ptr;
116
+ ptr += p->vocab_size * p->dim;
117
+ w->rms_att_weight = ptr;
118
+ ptr += n_layers * p->dim;
119
+ w->wq = ptr;
120
+ ptr += n_layers * p->dim * (p->n_heads * head_size);
121
+ w->wk = ptr;
122
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
123
+ w->wv = ptr;
124
+ ptr += n_layers * p->dim * (p->n_kv_heads * head_size);
125
+ w->wo = ptr;
126
+ ptr += n_layers * (p->n_heads * head_size) * p->dim;
127
+ w->rms_ffn_weight = ptr;
128
+ ptr += n_layers * p->dim;
129
+ w->w1 = ptr;
130
+ ptr += n_layers * p->dim * p->hidden_dim;
131
+ w->w2 = ptr;
132
+ ptr += n_layers * p->hidden_dim * p->dim;
133
+ w->w3 = ptr;
134
+ ptr += n_layers * p->dim * p->hidden_dim;
135
+ w->rms_final_weight = ptr;
136
+ ptr += p->dim;
137
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_real (for RoPE)
138
+ ptr += p->seq_len * head_size / 2; // skip what used to be freq_cis_imag (for RoPE)
139
+ w->wcls = shared_weights ? w->token_embedding_table : ptr;
140
+ }
141
+
142
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
143
+ int* fd, float** data, ssize_t* file_size) {
144
+ FILE *file = fopen(checkpoint, "rb");
145
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
146
+ // read in the config header
147
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
148
+ // negative vocab size is hacky way of signaling unshared weights. bit yikes.
149
+ int shared_weights = config->vocab_size > 0 ? 1 : 0;
150
+ config->vocab_size = abs(config->vocab_size);
151
+ // figure out the file size
152
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
153
+ *file_size = ftell(file); // get the file size, in bytes
154
+ fclose(file);
155
+ // memory map the Transformer weights into the data pointer
156
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
157
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
158
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
159
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
160
+ float* weights_ptr = *data + sizeof(Config)/sizeof(float);
161
+ memory_map_weights(weights, config, weights_ptr, shared_weights);
162
+ }
163
+
164
+ void build_transformer(Transformer *t, char* checkpoint_path) {
165
+ // read in the Config and the Weights from the checkpoint
166
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
167
+ // allocate the RunState buffers
168
+ malloc_run_state(&t->state, &t->config);
169
+ }
170
+
171
+ void free_transformer(Transformer* t) {
172
+ // close the memory mapping
173
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
174
+ if (t->fd != -1) { close(t->fd); }
175
+ // free the RunState buffers
176
+ free_run_state(&t->state);
177
+ }
178
+
179
+ // ----------------------------------------------------------------------------
180
+ // neural net blocks; the dynamics of the Transformer
181
+
182
+ void rmsnorm(float* o, float* x, float* weight, int size) {
183
+ // calculate sum of squares
184
+ float ss = 0.0f;
185
+ for (int j = 0; j < size; j++) {
186
+ ss += x[j] * x[j];
187
+ }
188
+ ss /= size;
189
+ ss += 1e-5f;
190
+ ss = 1.0f / sqrtf(ss);
191
+ // normalize and scale
192
+ for (int j = 0; j < size; j++) {
193
+ o[j] = weight[j] * (ss * x[j]);
194
+ }
195
+ }
196
+
197
+ void softmax(float* x, int size) {
198
+ // find max value (for numerical stability)
199
+ float max_val = x[0];
200
+ for (int i = 1; i < size; i++) {
201
+ if (x[i] > max_val) {
202
+ max_val = x[i];
203
+ }
204
+ }
205
+ // exp and sum
206
+ float sum = 0.0f;
207
+ for (int i = 0; i < size; i++) {
208
+ x[i] = expf(x[i] - max_val);
209
+ sum += x[i];
210
+ }
211
+ // normalize
212
+ for (int i = 0; i < size; i++) {
213
+ x[i] /= sum;
214
+ }
215
+ }
216
+
217
+ void matmul(float* xout, float* x, float* w, int n, int d) {
218
+ // W (d,n) @ x (n,) -> xout (d,)
219
+ // by far the most amount of time is spent inside this little function
220
+ int i;
221
+ #pragma omp parallel for private(i)
222
+ for (i = 0; i < d; i++) {
223
+ float val = 0.0f;
224
+ for (int j = 0; j < n; j++) {
225
+ val += w[i * n + j] * x[j];
226
+ }
227
+ xout[i] = val;
228
+ }
229
+ }
230
+
231
+ float* forward(Transformer* transformer, int token, int pos) {
232
+
233
+ // a few convenience variables
234
+ Config* p = &transformer->config;
235
+ TransformerWeights* w = &transformer->weights;
236
+ RunState* s = &transformer->state;
237
+ float *x = s->x;
238
+ int dim = p->dim;
239
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
240
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
241
+ int hidden_dim = p->hidden_dim;
242
+ int head_size = dim / p->n_heads;
243
+
244
+ // copy the token embedding into x
245
+ float* content_row = w->token_embedding_table + token * dim;
246
+ memcpy(x, content_row, dim*sizeof(*x));
247
+
248
+ // forward all the layers
249
+ for(unsigned long long l = 0; l < p->n_layers; l++) {
250
+
251
+ // attention rmsnorm
252
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
253
+
254
+ // key and value point to the kv cache
255
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
256
+ s->k = s->key_cache + loff + pos * kv_dim;
257
+ s->v = s->value_cache + loff + pos * kv_dim;
258
+
259
+ // qkv matmuls for this position
260
+ matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
261
+ matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
262
+ matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
263
+
264
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
265
+ for (int i = 0; i < dim; i+=2) {
266
+ int head_dim = i % head_size;
267
+ float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
268
+ float val = pos * freq;
269
+ float fcr = cosf(val);
270
+ float fci = sinf(val);
271
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
272
+ for (int v = 0; v < rotn; v++) {
273
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
274
+ float v0 = vec[i];
275
+ float v1 = vec[i+1];
276
+ vec[i] = v0 * fcr - v1 * fci;
277
+ vec[i+1] = v0 * fci + v1 * fcr;
278
+ }
279
+ }
280
+
281
+ // multihead attention. iterate over all heads
282
+ int h;
283
+ #pragma omp parallel for private(h)
284
+ for (h = 0; h < p->n_heads; h++) {
285
+ // get the query vector for this head
286
+ float* q = s->q + h * head_size;
287
+ // attention scores for this head
288
+ float* att = s->att + h * p->seq_len;
289
+ // iterate over all timesteps, including the current one
290
+ for (int t = 0; t <= pos; t++) {
291
+ // get the key vector for this head and at this timestep
292
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
293
+ // calculate the attention score as the dot product of q and k
294
+ float score = 0.0f;
295
+ for (int i = 0; i < head_size; i++) {
296
+ score += q[i] * k[i];
297
+ }
298
+ score /= sqrtf(head_size);
299
+ // save the score to the attention buffer
300
+ att[t] = score;
301
+ }
302
+
303
+ // softmax the scores to get attention weights, from 0..pos inclusively
304
+ softmax(att, pos + 1);
305
+
306
+ // weighted sum of the values, store back into xb
307
+ float* xb = s->xb + h * head_size;
308
+ memset(xb, 0, head_size * sizeof(float));
309
+ for (int t = 0; t <= pos; t++) {
310
+ // get the value vector for this head and at this timestep
311
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
312
+ // get the attention weight for this timestep
313
+ float a = att[t];
314
+ // accumulate the weighted value into xb
315
+ for (int i = 0; i < head_size; i++) {
316
+ xb[i] += a * v[i];
317
+ }
318
+ }
319
+ }
320
+
321
+ // final matmul to get the output of the attention
322
+ matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
323
+
324
+ // residual connection back into x
325
+ for (int i = 0; i < dim; i++) {
326
+ x[i] += s->xb2[i];
327
+ }
328
+
329
+ // ffn rmsnorm
330
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
331
+
332
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
333
+ // first calculate self.w1(x) and self.w3(x)
334
+ matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
335
+ matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
336
+
337
+ // SwiGLU non-linearity
338
+ for (int i = 0; i < hidden_dim; i++) {
339
+ float val = s->hb[i];
340
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
341
+ val *= (1.0f / (1.0f + expf(-val)));
342
+ // elementwise multiply with w3(x)
343
+ val *= s->hb2[i];
344
+ s->hb[i] = val;
345
+ }
346
+
347
+ // final matmul to get the output of the ffn
348
+ matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
349
+
350
+ // residual connection
351
+ for (int i = 0; i < dim; i++) {
352
+ x[i] += s->xb[i];
353
+ }
354
+ }
355
+
356
+ // final rmsnorm
357
+ rmsnorm(x, x, w->rms_final_weight, dim);
358
+
359
+ // classifier into logits
360
+ matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
361
+ return s->logits;
362
+ }
363
+
364
+ // ----------------------------------------------------------------------------
365
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
366
+
367
+ typedef struct {
368
+ char *str;
369
+ int id;
370
+ } TokenIndex;
371
+
372
+ typedef struct {
373
+ char** vocab;
374
+ float* vocab_scores;
375
+ TokenIndex *sorted_vocab;
376
+ int vocab_size;
377
+ unsigned int max_token_length;
378
+ unsigned char byte_pieces[512]; // stores all single-byte strings
379
+ } Tokenizer;
380
+
381
+ int compare_tokens(const void *a, const void *b) {
382
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
383
+ }
384
+
385
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
386
+ // i should have written the vocab_size into the tokenizer file... sigh
387
+ t->vocab_size = vocab_size;
388
+ // malloc space to hold the scores and the strings
389
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
390
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
391
+ t->sorted_vocab = NULL; // initialized lazily
392
+ for (int i = 0; i < 256; i++) {
393
+ t->byte_pieces[i * 2] = (unsigned char)i;
394
+ t->byte_pieces[i * 2 + 1] = '\0';
395
+ }
396
+ // read in the file
397
+ FILE *file = fopen(tokenizer_path, "rb");
398
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
399
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
400
+ int len;
401
+ for (int i = 0; i < vocab_size; i++) {
402
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
403
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
404
+ t->vocab[i] = (char *)malloc(len + 1);
405
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
406
+ t->vocab[i][len] = '\0'; // add the string terminating token
407
+ }
408
+ fclose(file);
409
+ }
410
+
411
+ void free_tokenizer(Tokenizer* t) {
412
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
413
+ free(t->vocab);
414
+ free(t->vocab_scores);
415
+ free(t->sorted_vocab);
416
+ }
417
+
418
+ char* decode(Tokenizer* t, int prev_token, int token) {
419
+ char *piece = t->vocab[token];
420
+ // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
421
+ if (prev_token == 1 && piece[0] == ' ') { piece++; }
422
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
423
+ // parse this and convert and return the actual byte
424
+ unsigned char byte_val;
425
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
426
+ piece = (char*)t->byte_pieces + byte_val * 2;
427
+ }
428
+ return piece;
429
+ }
430
+
431
+ void safe_printf(char *piece) {
432
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
433
+ // because some of the other bytes can be various control codes, backspace, etc.
434
+ if (piece == NULL) { return; }
435
+ if (piece[0] == '\0') { return; }
436
+ if (piece[1] == '\0') {
437
+ unsigned char byte_val = piece[0];
438
+ if (!(isprint(byte_val) || isspace(byte_val))) {
439
+ return; // bad byte, don't print it
440
+ }
441
+ }
442
+ printf("%s", piece);
443
+ }
444
+
445
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
446
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
447
+ TokenIndex tok = { .str = str }; // acts as the key to search for
448
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
449
+ return res != NULL ? res->id : -1;
450
+ }
451
+
452
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
453
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
454
+ // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
455
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
456
+
457
+ if (t->sorted_vocab == NULL) {
458
+ // lazily malloc and sort the vocabulary
459
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
460
+ for (int i = 0; i < t->vocab_size; i++) {
461
+ t->sorted_vocab[i].str = t->vocab[i];
462
+ t->sorted_vocab[i].id = i;
463
+ }
464
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
465
+ }
466
+
467
+ // create a temporary buffer that will store merge candidates of always two consecutive tokens
468
+ // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
469
+ char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
470
+ size_t str_len = 0;
471
+
472
+ // start at 0 tokens
473
+ *n_tokens = 0;
474
+
475
+ // add optional BOS (=1) token, if desired
476
+ if (bos) tokens[(*n_tokens)++] = 1;
477
+
478
+ // add_dummy_prefix is true by default
479
+ // so prepend a dummy prefix token to the input string, but only if text != ""
480
+ // TODO: pretty sure this isn't correct in the general case but I don't have the
481
+ // energy to read more of the sentencepiece code to figure out what it's doing
482
+ if (text[0] != '\0') {
483
+ int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
484
+ tokens[(*n_tokens)++] = dummy_prefix;
485
+ }
486
+
487
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
488
+ // Code point ↔ UTF-8 conversion
489
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
490
+ // U+0000 U+007F 0xxxxxxx
491
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
492
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
493
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
494
+
495
+ // process the raw (UTF-8) byte sequence of the input string
496
+ for (char *c = text; *c != '\0'; c++) {
497
+
498
+ // reset buffer if the current byte is ASCII or a leading byte
499
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
500
+ // 0x80 is 10000000
501
+ // in UTF-8, all continuation bytes start with "10" in first two bits
502
+ // so in English this is: "if this byte is not a continuation byte"
503
+ if ((*c & 0xC0) != 0x80) {
504
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
505
+ // => reset our location, as we're starting a new UTF-8 codepoint
506
+ str_len = 0;
507
+ }
508
+
509
+ // append the current byte to the buffer
510
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
511
+ str_buffer[str_len] = '\0';
512
+
513
+ // while the next character is a continuation byte, continue appending
514
+ // but if there are too many of them, just stop to avoid overruning str_buffer size.
515
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
516
+ continue;
517
+ }
518
+
519
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
520
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
521
+
522
+ if (id != -1) {
523
+ // we found this codepoint in vocab, add it as a token
524
+ tokens[(*n_tokens)++] = id;
525
+ } else {
526
+ // byte_fallback encoding: just encode each byte as a token
527
+ // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
528
+ // so the individual bytes only start at index 3
529
+ for (int i=0; i < str_len; i++) {
530
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
531
+ }
532
+ }
533
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
534
+ }
535
+
536
+ // merge the best consecutive pair each iteration, according the scores in vocab_scores
537
+ while (1) {
538
+ float best_score = -1e10;
539
+ int best_id = -1;
540
+ int best_idx = -1;
541
+
542
+ for (int i=0; i < (*n_tokens-1); i++) {
543
+ // check if we can merge the pair (tokens[i], tokens[i+1])
544
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
545
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
546
+ if (id != -1 && t->vocab_scores[id] > best_score) {
547
+ // this merge pair exists in vocab! record its score and position
548
+ best_score = t->vocab_scores[id];
549
+ best_id = id;
550
+ best_idx = i;
551
+ }
552
+ }
553
+
554
+ if (best_idx == -1) {
555
+ break; // we couldn't find any more pairs to merge, so we're done
556
+ }
557
+
558
+ // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
559
+ tokens[best_idx] = best_id;
560
+ // delete token at position best_idx+1, shift the entire sequence back 1
561
+ for (int i = best_idx+1; i < (*n_tokens-1); i++) {
562
+ tokens[i] = tokens[i+1];
563
+ }
564
+ (*n_tokens)--; // token length decreased
565
+ }
566
+
567
+ // add optional EOS (=2) token, if desired
568
+ if (eos) tokens[(*n_tokens)++] = 2;
569
+
570
+ free(str_buffer);
571
+ }
572
+
573
+ // ----------------------------------------------------------------------------
574
+ // The Sampler, which takes logits and returns a sampled token
575
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
576
+
577
+ typedef struct {
578
+ float prob;
579
+ int index;
580
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
581
+
582
+ typedef struct {
583
+ int vocab_size;
584
+ ProbIndex* probindex; // buffer used in top-p sampling
585
+ float temperature;
586
+ float topp;
587
+ unsigned long long rng_state;
588
+ } Sampler;
589
+
590
+ int sample_argmax(float* probabilities, int n) {
591
+ // return the index that has the highest probability
592
+ int max_i = 0;
593
+ float max_p = probabilities[0];
594
+ for (int i = 1; i < n; i++) {
595
+ if (probabilities[i] > max_p) {
596
+ max_i = i;
597
+ max_p = probabilities[i];
598
+ }
599
+ }
600
+ return max_i;
601
+ }
602
+
603
+ int sample_mult(float* probabilities, int n, float coin) {
604
+ // sample index from probabilities (they must sum to 1!)
605
+ // coin is a random number in [0, 1), usually from random_f32()
606
+ float cdf = 0.0f;
607
+ for (int i = 0; i < n; i++) {
608
+ cdf += probabilities[i];
609
+ if (coin < cdf) {
610
+ return i;
611
+ }
612
+ }
613
+ return n - 1; // in case of rounding errors
614
+ }
615
+
616
+ int compare(const void* a, const void* b) {
617
+ ProbIndex* a_ = (ProbIndex*) a;
618
+ ProbIndex* b_ = (ProbIndex*) b;
619
+ if (a_->prob > b_->prob) return -1;
620
+ if (a_->prob < b_->prob) return 1;
621
+ return 0;
622
+ }
623
+
624
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
625
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
626
+ // tokens that exceed probability topp. This way we never sample tokens that
627
+ // have very low probabilities and are less likely to go "off the rails".
628
+ // coin is a random number in [0, 1), usually from random_f32()
629
+
630
+ int n0 = 0;
631
+ // quicksort indices in descending order of probabilities
632
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result
633
+ // so for efficiency we crop these out as candidates before sorting
634
+ const float cutoff = (1.0f - topp) / (n - 1);
635
+ for (int i = 0; i < n; i++) {
636
+ if (probabilities[i] >= cutoff) {
637
+ probindex[n0].index = i;
638
+ probindex[n0].prob = probabilities[i];
639
+ n0++;
640
+ }
641
+ }
642
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
643
+
644
+ // truncate the list where cumulative probability exceeds topp
645
+ float cumulative_prob = 0.0f;
646
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
647
+ for (int i = 0; i < n0; i++) {
648
+ cumulative_prob += probindex[i].prob;
649
+ if (cumulative_prob > topp) {
650
+ last_idx = i;
651
+ break; // we've exceeded topp by including last_idx
652
+ }
653
+ }
654
+
655
+ // sample from the truncated list
656
+ float r = coin * cumulative_prob;
657
+ float cdf = 0.0f;
658
+ for (int i = 0; i <= last_idx; i++) {
659
+ cdf += probindex[i].prob;
660
+ if (r < cdf) {
661
+ return probindex[i].index;
662
+ }
663
+ }
664
+ return probindex[last_idx].index; // in case of rounding errors
665
+ }
666
+
667
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
668
+ sampler->vocab_size = vocab_size;
669
+ sampler->temperature = temperature;
670
+ sampler->topp = topp;
671
+ sampler->rng_state = rng_seed;
672
+ // buffer only used with nucleus sampling; may not need but it's ~small
673
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
674
+ }
675
+
676
+ void free_sampler(Sampler* sampler) {
677
+ free(sampler->probindex);
678
+ }
679
+
680
+ unsigned int random_u32(unsigned long long *state) {
681
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
682
+ *state ^= *state >> 12;
683
+ *state ^= *state << 25;
684
+ *state ^= *state >> 27;
685
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
686
+ }
687
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
688
+ return (random_u32(state) >> 8) / 16777216.0f;
689
+ }
690
+
691
+ int sample(Sampler* sampler, float* logits) {
692
+ // sample the token given the logits and some hyperparameters
693
+ int next;
694
+ if (sampler->temperature == 0.0f) {
695
+ // greedy argmax sampling: take the token with the highest probability
696
+ next = sample_argmax(logits, sampler->vocab_size);
697
+ } else {
698
+ // apply the temperature to the logits
699
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
700
+ // apply softmax to the logits to get the probabilities for next token
701
+ softmax(logits, sampler->vocab_size);
702
+ // flip a (float) coin (this is our source of entropy for sampling)
703
+ float coin = random_f32(&sampler->rng_state);
704
+ // we sample from this distribution to get the next token
705
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
706
+ // simply sample from the predicted probability distribution
707
+ next = sample_mult(logits, sampler->vocab_size, coin);
708
+ } else {
709
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
710
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
711
+ }
712
+ }
713
+ return next;
714
+ }
715
+
716
+ // ----------------------------------------------------------------------------
717
+ // utilities: time
718
+
719
+ long time_in_ms() {
720
+ // return time in milliseconds, for benchmarking the model speed
721
+ struct timespec time;
722
+ clock_gettime(CLOCK_REALTIME, &time);
723
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
724
+ }
725
+
726
+ // ----------------------------------------------------------------------------
727
+ // generation loop
728
+
729
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
730
+ char *empty_prompt = "";
731
+ if (prompt == NULL) { prompt = empty_prompt; }
732
+
733
+ // encode the (string) prompt into tokens sequence
734
+ int num_prompt_tokens = 0;
735
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
736
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
737
+ if (num_prompt_tokens < 1) {
738
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
739
+ exit(EXIT_FAILURE);
740
+ }
741
+
742
+ // start the main loop
743
+ long start = 0; // used to time our code, only initialized after first iteration
744
+ int next; // will store the next token in the sequence
745
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
746
+ int pos = 0; // position in the sequence
747
+ while (pos < steps) {
748
+
749
+ // forward the transformer to get logits for the next token
750
+ float* logits = forward(transformer, token, pos);
751
+
752
+ // advance the state machine
753
+ if (pos < num_prompt_tokens - 1) {
754
+ // if we are still processing the input prompt, force the next prompt token
755
+ next = prompt_tokens[pos + 1];
756
+ } else {
757
+ // otherwise sample the next token from the logits
758
+ next = sample(sampler, logits);
759
+ }
760
+ pos++;
761
+
762
+ // data-dependent terminating condition: the BOS (=1) token delimits sequences
763
+ if (next == 1) { break; }
764
+
765
+ // print the token as string, decode it with the Tokenizer object
766
+ char* piece = decode(tokenizer, token, next);
767
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
768
+ fflush(stdout);
769
+ token = next;
770
+
771
+ // init the timer here because the first iteration can be slower
772
+ if (start == 0) { start = time_in_ms(); }
773
+ }
774
+ printf("\n");
775
+
776
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
777
+ if (pos > 1) {
778
+ long end = time_in_ms();
779
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
780
+ }
781
+
782
+ free(prompt_tokens);
783
+ }
784
+
785
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
786
+ // read a line from stdin, up to but not including \n
787
+ printf("%s", guide);
788
+ if (fgets(buffer, bufsize, stdin) != NULL) {
789
+ size_t len = strlen(buffer);
790
+ if (len > 0 && buffer[len - 1] == '\n') {
791
+ buffer[len - 1] = '\0'; // strip newline
792
+ }
793
+ }
794
+ }
795
+
796
+ // ----------------------------------------------------------------------------
797
+ // chat loop
798
+ // I manually inspected the tokens for a few chat conversations compared to
799
+ // python reference and that seemed ok, but this was not thoroughly tested and
800
+ // is not safely implemented, it's more a proof of concept atm.
801
+
802
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
803
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
804
+
805
+ // buffers for reading the system prompt and user prompt from stdin
806
+ // you'll notice they are soomewhat haphazardly and unsafely set atm
807
+ char system_prompt[512];
808
+ char user_prompt[512];
809
+ char rendered_prompt[1152];
810
+ int num_prompt_tokens = 0;
811
+ int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
812
+ int user_idx;
813
+
814
+ // start the main loop
815
+ int8_t user_turn = 1; // user starts
816
+ int next; // will store the next token in the sequence
817
+ int token; // stores the current token to feed into the transformer
818
+ int prev_token;
819
+ int pos = 0; // position in the sequence
820
+ while (pos < steps) {
821
+
822
+ // when it is the user's turn to contribute tokens to the dialog...
823
+ if (user_turn) {
824
+ // get the (optional) system prompt at position 0
825
+ if (pos == 0) {
826
+ // at position 0, the user can also contribute a system prompt
827
+ if (cli_system_prompt == NULL) {
828
+ // system prompt was not passed in, attempt to get it from stdin
829
+ read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
830
+ } else {
831
+ // system prompt was passed in, use it
832
+ strcpy(system_prompt, cli_system_prompt);
833
+ }
834
+ }
835
+ // get the user prompt
836
+ if (pos == 0 && cli_user_prompt != NULL) {
837
+ // user prompt for position 0 was passed in, use it
838
+ strcpy(user_prompt, cli_user_prompt);
839
+ } else {
840
+ // otherwise get user prompt from stdin
841
+ read_stdin("User: ", user_prompt, sizeof(user_prompt));
842
+ }
843
+ // render user/system prompts into the Llama 2 Chat schema
844
+ if (pos == 0 && system_prompt[0] != '\0') {
845
+ char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
846
+ sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
847
+ } else {
848
+ char user_template[] = "[INST] %s [/INST]";
849
+ sprintf(rendered_prompt, user_template, user_prompt);
850
+ }
851
+ // encode the rendered prompt into tokens
852
+ encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
853
+ user_idx = 0; // reset the user index
854
+ user_turn = 0;
855
+ printf("Assistant: ");
856
+ }
857
+
858
+ // determine the token to pass into the transformer next
859
+ if (user_idx < num_prompt_tokens) {
860
+ // if we are still processing the input prompt, force the next prompt token
861
+ token = prompt_tokens[user_idx++];
862
+ } else {
863
+ // otherwise use the next token sampled from previous turn
864
+ token = next;
865
+ }
866
+ // EOS (=2) token ends the Assistant turn
867
+ if (token == 2) { user_turn = 1; }
868
+
869
+ // forward the transformer to get logits for the next token
870
+ float* logits = forward(transformer, token, pos);
871
+ next = sample(sampler, logits);
872
+ pos++;
873
+
874
+ if (user_idx >= num_prompt_tokens && next != 2) {
875
+ // the Assistant is responding, so print its output
876
+ char* piece = decode(tokenizer, token, next);
877
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
878
+ fflush(stdout);
879
+ }
880
+ if (next == 2) { printf("\n"); }
881
+ }
882
+ printf("\n");
883
+ free(prompt_tokens);
884
+ }
885
+
886
+
887
+ // ----------------------------------------------------------------------------
888
+ // CLI, include only if not testing
889
+ #ifndef TESTING
890
+
891
+ void error_usage() {
892
+ fprintf(stderr, "Usage: run <checkpoint> [options]\n");
893
+ fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
894
+ fprintf(stderr, "Options:\n");
895
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
896
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
897
+ fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
898
+ fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
899
+ fprintf(stderr, " -i <string> input prompt\n");
900
+ fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
901
+ fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
902
+ fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
903
+ exit(EXIT_FAILURE);
904
+ }
905
+
906
+ int main(int argc, char *argv[]) {
907
+
908
+ // default parameters
909
+ char *checkpoint_path = NULL; // e.g. out/model.bin
910
+ char *tokenizer_path = "tokenizer.bin";
911
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
912
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
913
+ int steps = 256; // number of steps to run for
914
+ char *prompt = NULL; // prompt string
915
+ unsigned long long rng_seed = 0; // seed rng with time by default
916
+ char *mode = "generate"; // generate|chat
917
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
918
+
919
+ // poor man's C argparse so we can override the defaults above from the command line
920
+ if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
921
+ for (int i = 2; i < argc; i+=2) {
922
+ // do some basic validation
923
+ if (i + 1 >= argc) { error_usage(); } // must have arg after flag
924
+ if (argv[i][0] != '-') { error_usage(); } // must start with dash
925
+ if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
926
+ // read in the args
927
+ if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
928
+ else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
929
+ else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
930
+ else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
931
+ else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
932
+ else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
933
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
934
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
935
+ else { error_usage(); }
936
+ }
937
+
938
+ // parameter validation/overrides
939
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
940
+ if (temperature < 0.0) temperature = 0.0;
941
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
942
+ if (steps < 0) steps = 0;
943
+
944
+ // build the Transformer via the model .bin file
945
+ Transformer transformer;
946
+ build_transformer(&transformer, checkpoint_path);
947
+ if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // ovrerride to ~max length
948
+
949
+ // build the Tokenizer via the tokenizer .bin file
950
+ Tokenizer tokenizer;
951
+ build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
952
+
953
+ // build the Sampler
954
+ Sampler sampler;
955
+ build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
956
+
957
+ // run!
958
+ if (strcmp(mode, "generate") == 0) {
959
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
960
+ } else if (strcmp(mode, "chat") == 0) {
961
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
962
+ } else {
963
+ fprintf(stderr, "unknown mode: %s\n", mode);
964
+ error_usage();
965
+ }
966
+
967
+ // memory and file handles cleanup
968
+ free_sampler(&sampler);
969
+ free_tokenizer(&tokenizer);
970
+ free_transformer(&transformer);
971
+ return 0;
972
+ }
973
+ #endif
run.ipynb ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "HLdoj4cz-xal"
7
+ },
8
+ "source": [
9
+ "# Run.c\n",
10
+ "\n",
11
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/karpathy/llama2.c/blob/master/run.ipynb)\n",
12
+ "\n",
13
+ "More details can be found in the [README.md](README.md) ."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {
20
+ "id": "Une3Ozlnu1B7"
21
+ },
22
+ "outputs": [],
23
+ "source": [
24
+ "#@title Clone Project\n",
25
+ "\n",
26
+ "!git clone https://github.com/karpathy/llama2.c.git\n",
27
+ "%cd llama2.c"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "#@title Build\n",
37
+ "\n",
38
+ "!make runfast"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "thm0ZBrtSgoC"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "#@title Pick Your Model\n",
50
+ "\n",
51
+ "#@markdown Choose model\n",
52
+ "model = \"stories15M\" #@param [\"stories15M\", \"stories42M\", \"stories110M\"]\n",
53
+ "\n",
54
+ "download_url = \"\"\n",
55
+ "\n",
56
+ "if(model == \"stories15M\"):\n",
57
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin\"\n",
58
+ "if(model == \"stories42M\"):\n",
59
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin\"\n",
60
+ "if(model == \"stories110M\"):\n",
61
+ " download_url = \"https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin\"\n",
62
+ "\n",
63
+ "print(f\"download_url: {download_url}\")\n",
64
+ "\n",
65
+ "!wget $download_url\n",
66
+ "\n",
67
+ "model_file = model + \".bin\""
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "metadata": {
74
+ "id": "OgAc3KjuT-NM"
75
+ },
76
+ "outputs": [],
77
+ "source": [
78
+ "#@title Generate Stories\n",
79
+ "\n",
80
+ "# Generate args\n",
81
+ "max_token = 256 #@param {type:\"slider\", min:32, max:1024, step:32}\n",
82
+ "temperature = 0.8 #@param {type:\"slider\", min:0.0, max:1, step:0.05}\n",
83
+ "top_p = 0.9 #@param {type:\"slider\", min:0.0, max:1.0, step:0.05}\n",
84
+ "prompt = \"One day, Lily met a Shoggoth\" #@param {type:\"string\"}\n",
85
+ "\n",
86
+ "print(f\"model: {model_file}, max_token: {max_token}, temperature: {temperature}, top_p: {top_p}, prompt: {prompt}\")\n",
87
+ "print(f\"----------------------------\\n\")\n",
88
+ "\n",
89
+ "cmd = f'./run {model_file} -t {temperature} -p {top_p} -n {max_token} -i \"{prompt}\"'\n",
90
+ "!{cmd}"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "#@title Run Meta's Llama 2 models\n",
100
+ "\n",
101
+ "#@markdown input your huggingface [access token](https://huggingface.co/settings/tokens) to download Meta's Llama 2 models.\n",
102
+ "\n",
103
+ "from huggingface_hub import snapshot_download\n",
104
+ "\n",
105
+ "token = \"replace your huggingface access token\" #@param {type:\"string\"}\n",
106
+ "path = snapshot_download(repo_id=\"meta-llama/Llama-2-7b\",cache_dir=\"Llama-2-7b\", use_auth_token=token)\n",
107
+ "\n",
108
+ "!python export_meta_llama_bin.py $path llama2_7b.bin\n",
109
+ "\n",
110
+ "print(\"./run llama2_7b.bin\\n\")\n",
111
+ "!./run llama2_7b.bin"
112
+ ]
113
+ }
114
+ ],
115
+ "metadata": {
116
+ "colab": {
117
+ "private_outputs": true,
118
+ "provenance": []
119
+ },
120
+ "kernelspec": {
121
+ "display_name": "Python 3",
122
+ "name": "python3"
123
+ },
124
+ "language_info": {
125
+ "name": "python"
126
+ }
127
+ },
128
+ "nbformat": 4,
129
+ "nbformat_minor": 0
130
+ }
runq.c ADDED
@@ -0,0 +1,1092 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /* Inference for Llama-2 Transformer model in pure C, int8 quantized forward pass. */
2
+
3
+ #include <stdio.h>
4
+ #include <stdlib.h>
5
+ #include <ctype.h>
6
+ #include <stdint.h>
7
+ #include <time.h>
8
+ #include <math.h>
9
+ #include <string.h>
10
+ #include <fcntl.h>
11
+ #if defined _WIN32
12
+ #include "win.h"
13
+ #else
14
+ #include <unistd.h>
15
+ #include <sys/mman.h>
16
+ #endif
17
+ // ----------------------------------------------------------------------------
18
+ // Globals
19
+ int GS = 0; // group size global for quantization of the weights
20
+
21
+ // ----------------------------------------------------------------------------
22
+ // Transformer model
23
+
24
+ typedef struct {
25
+ int dim; // transformer dimension
26
+ int hidden_dim; // for ffn layers
27
+ int n_layers; // number of layers
28
+ int n_heads; // number of query heads
29
+ int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
30
+ int vocab_size; // vocabulary size, usually 256 (byte-level)
31
+ int seq_len; // max sequence length
32
+ } Config;
33
+
34
+ typedef struct {
35
+ int8_t* q; // quantized values
36
+ float* s; // scaling factors
37
+ } QuantizedTensor;
38
+
39
+ typedef struct {
40
+ // token embedding table
41
+ QuantizedTensor *q_tokens; // (vocab_size, dim)
42
+ float* token_embedding_table; // same, but dequantized
43
+
44
+ // weights for rmsnorms
45
+ float* rms_att_weight; // (layer, dim) rmsnorm weights
46
+ float* rms_ffn_weight; // (layer, dim)
47
+ // weights for matmuls. note dim == n_heads * head_size
48
+ QuantizedTensor *wq; // (layer, dim, n_heads * head_size)
49
+ QuantizedTensor *wk; // (layer, dim, n_kv_heads * head_size)
50
+ QuantizedTensor *wv; // (layer, dim, n_kv_heads * head_size)
51
+ QuantizedTensor *wo; // (layer, n_heads * head_size, dim)
52
+ // weights for ffn
53
+ QuantizedTensor *w1; // (layer, hidden_dim, dim)
54
+ QuantizedTensor *w2; // (layer, dim, hidden_dim)
55
+ QuantizedTensor *w3; // (layer, hidden_dim, dim)
56
+ // final rmsnorm
57
+ float* rms_final_weight; // (dim,)
58
+ // (optional) classifier weights for the logits, on the last layer
59
+ QuantizedTensor *wcls;
60
+ } TransformerWeights;
61
+
62
+ typedef struct {
63
+ // current wave of activations
64
+ float *x; // activation at current time stamp (dim,)
65
+ float *xb; // same, but inside a residual branch (dim,)
66
+ float *xb2; // an additional buffer just for convenience (dim,)
67
+ float *hb; // buffer for hidden dimension in the ffn (hidden_dim,)
68
+ float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
69
+ QuantizedTensor xq; // quantized x (dim,)
70
+ QuantizedTensor hq; // quantized hb (hidden_dim,)
71
+ float *q; // query (dim,)
72
+ float *k; // key (dim,)
73
+ float *v; // value (dim,)
74
+ float *att; // buffer for scores/attention values (n_heads, seq_len)
75
+ float *logits; // output logits
76
+ // kv cache
77
+ float* key_cache; // (layer, seq_len, dim)
78
+ float* value_cache; // (layer, seq_len, dim)
79
+ } RunState;
80
+
81
+ typedef struct {
82
+ Config config; // the hyperparameters of the architecture (the blueprint)
83
+ TransformerWeights weights; // the weights of the model
84
+ RunState state; // buffers for the "wave" of activations in the forward pass
85
+ // some more state needed to properly clean up the memory mapping (sigh)
86
+ int fd; // file descriptor for memory mapping
87
+ float* data; // memory mapped data pointer
88
+ ssize_t file_size; // size of the checkpoint file in bytes
89
+ } Transformer;
90
+
91
+ void malloc_run_state(RunState* s, Config* p) {
92
+ // we calloc instead of malloc to keep valgrind happy
93
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
94
+ s->x = calloc(p->dim, sizeof(float));
95
+ s->xb = calloc(p->dim, sizeof(float));
96
+ s->xb2 = calloc(p->dim, sizeof(float));
97
+ s->hb = calloc(p->hidden_dim, sizeof(float));
98
+ s->hb2 = calloc(p->hidden_dim, sizeof(float));
99
+ s->xq = (QuantizedTensor) { .q = calloc(p->dim, sizeof(int8_t)), .s = calloc(p->dim, sizeof(float)) };
100
+ s->hq = (QuantizedTensor) { .q = calloc(p->hidden_dim, sizeof(int8_t)), .s = calloc(p->hidden_dim, sizeof(float)) };
101
+ s->q = calloc(p->dim, sizeof(float));
102
+ s->k = calloc(kv_dim, sizeof(float));
103
+ s->v = calloc(kv_dim, sizeof(float));
104
+ s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
105
+ s->logits = calloc(p->vocab_size, sizeof(float));
106
+ s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
107
+ s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
108
+ // ensure all mallocs went fine
109
+ if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
110
+ || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
111
+ || !s->value_cache) {
112
+ fprintf(stderr, "malloc failed!\n");
113
+ exit(EXIT_FAILURE);
114
+ }
115
+ }
116
+
117
+ void free_run_state(RunState* s) {
118
+ free(s->x);
119
+ free(s->xb);
120
+ free(s->xb2);
121
+ free(s->hb);
122
+ free(s->hb2);
123
+ free(s->xq.q);
124
+ free(s->xq.s);
125
+ free(s->hq.q);
126
+ free(s->hq.s);
127
+ free(s->q);
128
+ free(s->k);
129
+ free(s->v);
130
+ free(s->att);
131
+ free(s->logits);
132
+ free(s->key_cache);
133
+ free(s->value_cache);
134
+ }
135
+
136
+ // ----------------------------------------------------------------------------
137
+ // Quantization functions
138
+
139
+ void dequantize(QuantizedTensor *qx, float* x, int n) {
140
+ for (int i = 0; i < n; i++) {
141
+ x[i] = qx->q[i] * qx->s[i / GS];
142
+ }
143
+ }
144
+
145
+ void quantize(QuantizedTensor *qx, float* x, int n) {
146
+ int num_groups = n / GS;
147
+ float Q_MAX = 127.0f;
148
+
149
+ for (int group = 0; group < num_groups; group++) {
150
+
151
+ // find the max absolute value in the current group
152
+ float wmax = 0.0;
153
+ for (int i = 0; i < GS; i++) {
154
+ float val = fabs(x[group * GS + i]);
155
+ if (val > wmax) {
156
+ wmax = val;
157
+ }
158
+ }
159
+
160
+ // calculate and write the scaling factor
161
+ float scale = wmax / Q_MAX;
162
+ qx->s[group] = scale;
163
+
164
+ // calculate and write the quantized values
165
+ for (int i = 0; i < GS; i++) {
166
+ float quant_value = x[group * GS + i] / scale; // scale
167
+ int8_t quantized = (int8_t) round(quant_value); // round and clamp
168
+ qx->q[group * GS + i] = quantized;
169
+ }
170
+ }
171
+ }
172
+
173
+ /* initialize `n` x quantized tensor (with `size_each` elements), starting from memory pointed at *ptr */
174
+ QuantizedTensor *init_quantized_tensors(void **ptr, int n, int size_each) {
175
+ void *p = *ptr;
176
+ QuantizedTensor *res = malloc(n * sizeof(QuantizedTensor));
177
+ for(int i=0; i<n; i++) {
178
+ /* map quantized int8 values*/
179
+ res[i].q = (int8_t*)p;
180
+ p = (int8_t*)p + size_each;
181
+ /* map scale factors */
182
+ res[i].s = (float*)p;
183
+ p = (float*)p + size_each / GS;
184
+ }
185
+ *ptr = p; // advance ptr to current position
186
+ return res;
187
+ }
188
+
189
+ void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t shared_classifier) {
190
+ int head_size = p->dim / p->n_heads;
191
+ // first are the parameters that are kept in fp32 (the rmsnorm (1D) weights)
192
+ float* fptr = (float*) ptr; // cast our pointer to float*
193
+ w->rms_att_weight = fptr;
194
+ fptr += p->n_layers * p->dim;
195
+ w->rms_ffn_weight = fptr;
196
+ fptr += p->n_layers * p->dim;
197
+ w->rms_final_weight = fptr;
198
+ fptr += p->dim;
199
+
200
+ // now read all the quantized weights
201
+ ptr = (void*)fptr; // now cast the pointer back to void*
202
+ w->q_tokens = init_quantized_tensors(&ptr, 1, p->vocab_size * p->dim);
203
+ // dequantize token embedding table
204
+ w->token_embedding_table = malloc(p->vocab_size * p->dim * sizeof(float));
205
+ dequantize(w->q_tokens, w->token_embedding_table, p->vocab_size * p->dim);
206
+
207
+ w->wq = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_heads * head_size));
208
+ w->wk = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
209
+ w->wv = init_quantized_tensors(&ptr, p->n_layers, p->dim * (p->n_kv_heads * head_size));
210
+ w->wo = init_quantized_tensors(&ptr, p->n_layers, (p->n_heads * head_size) * p->dim);
211
+
212
+ w->w1 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
213
+ w->w2 = init_quantized_tensors(&ptr, p->n_layers, p->hidden_dim * p->dim);
214
+ w->w3 = init_quantized_tensors(&ptr, p->n_layers, p->dim * p->hidden_dim);
215
+
216
+ w->wcls = shared_classifier ? w->q_tokens : init_quantized_tensors(&ptr, 1, p->dim * p->vocab_size);
217
+ }
218
+
219
+ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
220
+ int* fd, float** data, ssize_t* file_size) {
221
+ FILE *file = fopen(checkpoint, "rb");
222
+ if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); exit(EXIT_FAILURE); }
223
+ // read in magic number (uint32), has to be 0x616b3432, i.e. "ak42" in ASCII
224
+ uint32_t magic_number;
225
+ if (fread(&magic_number, sizeof(uint32_t), 1, file) != 1) { exit(EXIT_FAILURE); }
226
+ if (magic_number != 0x616b3432) { fprintf(stderr, "Bad magic number\n"); exit(EXIT_FAILURE); }
227
+ // read in the version number (uint32), has to be 1
228
+ int version;
229
+ if (fread(&version, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
230
+ if (version != 2) { fprintf(stderr, "Bad version %d, need version 2\n", version); exit(EXIT_FAILURE); }
231
+ int header_size = 256; // the header size for version 2 in bytes
232
+ // read in the Config
233
+ if (fread(config, sizeof(Config), 1, file) != 1) { exit(EXIT_FAILURE); }
234
+ // read in flags
235
+ uint8_t shared_classifier; // a byte to indicate if the classifier is shared
236
+ if (fread(&shared_classifier, sizeof(uint8_t), 1, file) != 1) { exit(EXIT_FAILURE); }
237
+ int group_size; // the group size used in quantization
238
+ if (fread(&group_size, sizeof(int), 1, file) != 1) { exit(EXIT_FAILURE); }
239
+ GS = group_size; // set as global, as it will be used in many places
240
+ // figure out the file size
241
+ fseek(file, 0, SEEK_END); // move file pointer to end of file
242
+ *file_size = ftell(file); // get the file size, in bytes
243
+ fclose(file);
244
+ // memory map the Transformer weights into the data pointer
245
+ *fd = open(checkpoint, O_RDONLY); // open in read only mode
246
+ if (*fd == -1) { fprintf(stderr, "open failed!\n"); exit(EXIT_FAILURE); }
247
+ *data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
248
+ if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
249
+ void* weights_ptr = ((char*)*data) + header_size; // skip header bytes. char is 1 byte
250
+ memory_map_weights(weights, config, weights_ptr, shared_classifier);
251
+ }
252
+
253
+ void build_transformer(Transformer *t, char* checkpoint_path) {
254
+ // read in the Config and the Weights from the checkpoint
255
+ read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
256
+ // allocate the RunState buffers
257
+ malloc_run_state(&t->state, &t->config);
258
+ }
259
+
260
+ void free_transformer(Transformer* t) {
261
+ // free QuantizedTensors
262
+ free(t->weights.q_tokens);
263
+ free(t->weights.token_embedding_table);
264
+ free(t->weights.wq);
265
+ free(t->weights.wk);
266
+ free(t->weights.wv);
267
+ free(t->weights.wo);
268
+ free(t->weights.w1);
269
+ free(t->weights.w2);
270
+ free(t->weights.w3);
271
+ if(t->weights.wcls != t->weights.q_tokens) { free(t->weights.wcls); }
272
+ // close the memory mapping
273
+ if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
274
+ if (t->fd != -1) { close(t->fd); }
275
+ // free the RunState buffers
276
+ free_run_state(&t->state);
277
+ }
278
+
279
+ // ----------------------------------------------------------------------------
280
+ // neural net blocks; the dynamics of the Transformer
281
+
282
+ void rmsnorm(float* o, float* x, float* weight, int size) {
283
+ // calculate sum of squares
284
+ float ss = 0.0f;
285
+ for (int j = 0; j < size; j++) {
286
+ ss += x[j] * x[j];
287
+ }
288
+ ss /= size;
289
+ ss += 1e-5f;
290
+ ss = 1.0f / sqrtf(ss);
291
+ // normalize and scale
292
+ for (int j = 0; j < size; j++) {
293
+ o[j] = weight[j] * (ss * x[j]);
294
+ }
295
+ }
296
+
297
+ void softmax(float* x, int size) {
298
+ // find max value (for numerical stability)
299
+ float max_val = x[0];
300
+ for (int i = 1; i < size; i++) {
301
+ if (x[i] > max_val) {
302
+ max_val = x[i];
303
+ }
304
+ }
305
+ // exp and sum
306
+ float sum = 0.0f;
307
+ for (int i = 0; i < size; i++) {
308
+ x[i] = expf(x[i] - max_val);
309
+ sum += x[i];
310
+ }
311
+ // normalize
312
+ for (int i = 0; i < size; i++) {
313
+ x[i] /= sum;
314
+ }
315
+ }
316
+
317
+ void matmul(float* xout, QuantizedTensor *x, QuantizedTensor *w, int n, int d) {
318
+ // W (d,n) @ x (n,) -> xout (d,)
319
+ // by far the most amount of time is spent inside this little function
320
+ // inputs to this function are both quantized
321
+
322
+ int i;
323
+ #pragma omp parallel for private(i)
324
+ for (i = 0; i < d; i++) {
325
+
326
+ float val = 0.0f;
327
+ int32_t ival = 0;
328
+ int in = i * n;
329
+
330
+ // do the matmul in groups of GS
331
+ int j;
332
+ for (j = 0; j <= n - GS; j += GS) {
333
+ for (int k = 0; k < GS; k++) {
334
+ ival += ((int32_t) x->q[j + k]) * ((int32_t) w->q[in + j + k]);
335
+ }
336
+ val += ((float) ival) * w->s[(in + j) / GS] * x->s[j / GS];
337
+ ival = 0;
338
+ }
339
+
340
+ xout[i] = val;
341
+ }
342
+ }
343
+
344
+ float* forward(Transformer* transformer, int token, int pos) {
345
+
346
+ // a few convenience variables
347
+ Config* p = &transformer->config;
348
+ TransformerWeights* w = &transformer->weights;
349
+ RunState* s = &transformer->state;
350
+ float *x = s->x;
351
+ int dim = p->dim;
352
+ int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
353
+ int kv_mul = p->n_heads / p->n_kv_heads; // integer multiplier of the kv sharing in multiquery
354
+ int hidden_dim = p->hidden_dim;
355
+ int head_size = dim / p->n_heads;
356
+
357
+ // copy the token embedding into x
358
+ memcpy(x, w->token_embedding_table + token*dim, dim * sizeof(float));
359
+
360
+ // forward all the layers
361
+ for(int l = 0; l < p->n_layers; l++) {
362
+
363
+ // attention rmsnorm
364
+ rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
365
+
366
+ // qkv matmuls for this position
367
+ quantize(&s->xq, s->xb, dim);
368
+ matmul(s->q, &s->xq, w->wq + l, dim, dim);
369
+ matmul(s->k, &s->xq, w->wk + l, dim, kv_dim);
370
+ matmul(s->v, &s->xq, w->wv + l, dim, kv_dim);
371
+
372
+ // RoPE relative positional encoding: complex-valued rotate q and k in each head
373
+ for (int i = 0; i < dim; i+=2) {
374
+ int head_dim = i % head_size;
375
+ float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
376
+ float val = pos * freq;
377
+ float fcr = cosf(val);
378
+ float fci = sinf(val);
379
+ int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
380
+ for (int v = 0; v < rotn; v++) {
381
+ float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
382
+ float v0 = vec[i];
383
+ float v1 = vec[i+1];
384
+ vec[i] = v0 * fcr - v1 * fci;
385
+ vec[i+1] = v0 * fci + v1 * fcr;
386
+ }
387
+ }
388
+
389
+ // save key,value at this time step (pos) to our kv cache
390
+ int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
391
+ float* key_cache_row = s->key_cache + loff + pos * kv_dim;
392
+ float* value_cache_row = s->value_cache + loff + pos * kv_dim;
393
+ memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
394
+ memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
395
+
396
+ // multihead attention. iterate over all heads
397
+ int h;
398
+ #pragma omp parallel for private(h)
399
+ for (h = 0; h < p->n_heads; h++) {
400
+ // get the query vector for this head
401
+ float* q = s->q + h * head_size;
402
+ // attention scores for this head
403
+ float* att = s->att + h * p->seq_len;
404
+ // iterate over all timesteps, including the current one
405
+ for (int t = 0; t <= pos; t++) {
406
+ // get the key vector for this head and at this timestep
407
+ float* k = s->key_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
408
+ // calculate the attention score as the dot product of q and k
409
+ float score = 0.0f;
410
+ for (int i = 0; i < head_size; i++) {
411
+ score += q[i] * k[i];
412
+ }
413
+ score /= sqrtf(head_size);
414
+ // save the score to the attention buffer
415
+ att[t] = score;
416
+ }
417
+
418
+ // softmax the scores to get attention weights, from 0..pos inclusively
419
+ softmax(att, pos + 1);
420
+
421
+ // weighted sum of the values, store back into xb
422
+ float* xb = s->xb + h * head_size;
423
+ memset(xb, 0, head_size * sizeof(float));
424
+ for (int t = 0; t <= pos; t++) {
425
+ // get the value vector for this head and at this timestep
426
+ float* v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
427
+ // get the attention weight for this timestep
428
+ float a = att[t];
429
+ // accumulate the weighted value into xb
430
+ for (int i = 0; i < head_size; i++) {
431
+ xb[i] += a * v[i];
432
+ }
433
+ }
434
+ }
435
+
436
+ // final matmul to get the output of the attention
437
+ quantize(&s->xq, s->xb, dim);
438
+ matmul(s->xb2, &s->xq, w->wo + l, dim, dim);
439
+
440
+ // residual connection back into x
441
+ for (int i = 0; i < dim; i++) {
442
+ x[i] += s->xb2[i];
443
+ }
444
+
445
+ // ffn rmsnorm
446
+ rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
447
+
448
+ // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
449
+ // first calculate self.w1(x) and self.w3(x)
450
+ quantize(&s->xq, s->xb, dim);
451
+ matmul(s->hb, &s->xq, w->w1 + l, dim, hidden_dim);
452
+ matmul(s->hb2, &s->xq, w->w3 + l, dim, hidden_dim);
453
+
454
+ // SwiGLU non-linearity
455
+ for (int i = 0; i < hidden_dim; i++) {
456
+ float val = s->hb[i];
457
+ // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
458
+ val *= (1.0f / (1.0f + expf(-val)));
459
+ // elementwise multiply with w3(x)
460
+ val *= s->hb2[i];
461
+ s->hb[i] = val;
462
+ }
463
+
464
+ // final matmul to get the output of the ffn
465
+ quantize(&s->hq, s->hb, hidden_dim);
466
+ matmul(s->xb, &s->hq, w->w2 + l, hidden_dim, dim);
467
+
468
+ // residual connection
469
+ for (int i = 0; i < dim; i++) {
470
+ x[i] += s->xb[i];
471
+ }
472
+ }
473
+
474
+ // final rmsnorm
475
+ rmsnorm(x, x, w->rms_final_weight, dim);
476
+
477
+ // classifier into logits
478
+ quantize(&s->xq, x, dim);
479
+ matmul(s->logits, &s->xq, w->wcls, dim, p->vocab_size);
480
+ return s->logits;
481
+ }
482
+
483
+ // ----------------------------------------------------------------------------
484
+ // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
485
+
486
+ typedef struct {
487
+ char *str;
488
+ int id;
489
+ } TokenIndex;
490
+
491
+ typedef struct {
492
+ char** vocab;
493
+ float* vocab_scores;
494
+ TokenIndex *sorted_vocab;
495
+ int vocab_size;
496
+ unsigned int max_token_length;
497
+ unsigned char byte_pieces[512]; // stores all single-byte strings
498
+ } Tokenizer;
499
+
500
+ int compare_tokens(const void *a, const void *b) {
501
+ return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
502
+ }
503
+
504
+ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
505
+ // i should have written the vocab_size into the tokenizer file... sigh
506
+ t->vocab_size = vocab_size;
507
+ // malloc space to hold the scores and the strings
508
+ t->vocab = (char**)malloc(vocab_size * sizeof(char*));
509
+ t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
510
+ t->sorted_vocab = NULL; // initialized lazily
511
+ for (int i = 0; i < 256; i++) {
512
+ t->byte_pieces[i * 2] = (unsigned char)i;
513
+ t->byte_pieces[i * 2 + 1] = '\0';
514
+ }
515
+ // read in the file
516
+ FILE *file = fopen(tokenizer_path, "rb");
517
+ if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
518
+ if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
519
+ int len;
520
+ for (int i = 0; i < vocab_size; i++) {
521
+ if (fread(t->vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE);}
522
+ if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
523
+ t->vocab[i] = (char *)malloc(len + 1);
524
+ if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
525
+ t->vocab[i][len] = '\0'; // add the string terminating token
526
+ }
527
+ fclose(file);
528
+ }
529
+
530
+ void free_tokenizer(Tokenizer* t) {
531
+ for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
532
+ free(t->vocab);
533
+ free(t->vocab_scores);
534
+ free(t->sorted_vocab);
535
+ }
536
+
537
+ char* decode(Tokenizer* t, int prev_token, int token) {
538
+ char *piece = t->vocab[token];
539
+ // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
540
+ if (prev_token == 1 && piece[0] == ' ') { piece++; }
541
+ // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
542
+ // parse this and convert and return the actual byte
543
+ unsigned char byte_val;
544
+ if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
545
+ piece = (char*)t->byte_pieces + byte_val * 2;
546
+ }
547
+ return piece;
548
+ }
549
+
550
+ void safe_printf(char *piece) {
551
+ // piece might be a raw byte token, and we only want to print printable chars or whitespace
552
+ // because some of the other bytes can be various control codes, backspace, etc.
553
+ if (piece == NULL) { return; }
554
+ if (piece[0] == '\0') { return; }
555
+ if (piece[1] == '\0') {
556
+ unsigned char byte_val = piece[0];
557
+ if (!(isprint(byte_val) || isspace(byte_val))) {
558
+ return; // bad byte, don't print it
559
+ }
560
+ }
561
+ printf("%s", piece);
562
+ }
563
+
564
+ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
565
+ // efficiently find the perfect match for str in vocab, return its index or -1 if not found
566
+ TokenIndex tok = { .str = str }; // acts as the key to search for
567
+ TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
568
+ return res != NULL ? res->id : -1;
569
+ }
570
+
571
+ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
572
+ // encode the string text (input) into an upper-bound preallocated tokens[] array
573
+ // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
574
+ if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
575
+
576
+ if (t->sorted_vocab == NULL) {
577
+ // lazily malloc and sort the vocabulary
578
+ t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
579
+ for (int i = 0; i < t->vocab_size; i++) {
580
+ t->sorted_vocab[i].str = t->vocab[i];
581
+ t->sorted_vocab[i].id = i;
582
+ }
583
+ qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
584
+ }
585
+
586
+ // create a temporary buffer that will store merge candidates of always two consecutive tokens
587
+ // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_length is 1)
588
+ char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char));
589
+ size_t str_len = 0;
590
+
591
+ // start at 0 tokens
592
+ *n_tokens = 0;
593
+
594
+ // add optional BOS (=1) token, if desired
595
+ if (bos) tokens[(*n_tokens)++] = 1;
596
+
597
+ // add_dummy_prefix is true by default
598
+ // so prepend a dummy prefix token to the input string, but only if text != ""
599
+ // TODO: pretty sure this isn't correct in the general case but I don't have the
600
+ // energy to read more of the sentencepiece code to figure out what it's doing
601
+ if (text[0] != '\0') {
602
+ int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
603
+ tokens[(*n_tokens)++] = dummy_prefix;
604
+ }
605
+
606
+ // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
607
+ // Code point ↔ UTF-8 conversion
608
+ // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
609
+ // U+0000 U+007F 0xxxxxxx
610
+ // U+0080 U+07FF 110xxxxx 10xxxxxx
611
+ // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
612
+ // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
613
+
614
+ // process the raw (UTF-8) byte sequence of the input string
615
+ for (char *c = text; *c != '\0'; c++) {
616
+
617
+ // reset buffer if the current byte is ASCII or a leading byte
618
+ // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the rest
619
+ // 0x80 is 10000000
620
+ // in UTF-8, all continuation bytes start with "10" in first two bits
621
+ // so in English this is: "if this byte is not a continuation byte"
622
+ if ((*c & 0xC0) != 0x80) {
623
+ // this byte must be either a leading byte (11...) or an ASCII char (0x...)
624
+ // => reset our location, as we're starting a new UTF-8 codepoint
625
+ str_len = 0;
626
+ }
627
+
628
+ // append the current byte to the buffer
629
+ str_buffer[str_len++] = *c; // ++ is post-increment, incremented after this line
630
+ str_buffer[str_len] = '\0';
631
+
632
+ // while the next character is a continuation byte, continue appending
633
+ // but if there are too many of them, just stop to avoid overruning str_buffer size.
634
+ if ((*(c+1) & 0xC0) == 0x80 && str_len < 4) {
635
+ continue;
636
+ }
637
+
638
+ // ok c+1 is not a continuation byte, so we've read in a full codepoint
639
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
640
+
641
+ if (id != -1) {
642
+ // we found this codepoint in vocab, add it as a token
643
+ tokens[(*n_tokens)++] = id;
644
+ } else {
645
+ // byte_fallback encoding: just encode each byte as a token
646
+ // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
647
+ // so the individual bytes only start at index 3
648
+ for (int i=0; i < str_len; i++) {
649
+ tokens[(*n_tokens)++] = (unsigned char)str_buffer[i] + 3;
650
+ }
651
+ }
652
+ str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
653
+ }
654
+
655
+ // merge the best consecutive pair each iteration, according the scores in vocab_scores
656
+ while (1) {
657
+ float best_score = -1e10;
658
+ int best_id = -1;
659
+ int best_idx = -1;
660
+
661
+ for (int i=0; i < (*n_tokens-1); i++) {
662
+ // check if we can merge the pair (tokens[i], tokens[i+1])
663
+ sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
664
+ int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
665
+ if (id != -1 && t->vocab_scores[id] > best_score) {
666
+ // this merge pair exists in vocab! record its score and position
667
+ best_score = t->vocab_scores[id];
668
+ best_id = id;
669
+ best_idx = i;
670
+ }
671
+ }
672
+
673
+ if (best_idx == -1) {
674
+ break; // we couldn't find any more pairs to merge, so we're done
675
+ }
676
+
677
+ // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
678
+ tokens[best_idx] = best_id;
679
+ // delete token at position best_idx+1, shift the entire sequence back 1
680
+ for (int i = best_idx+1; i < (*n_tokens-1); i++) {
681
+ tokens[i] = tokens[i+1];
682
+ }
683
+ (*n_tokens)--; // token length decreased
684
+ }
685
+
686
+ // add optional EOS (=2) token, if desired
687
+ if (eos) tokens[(*n_tokens)++] = 2;
688
+
689
+ free(str_buffer);
690
+ }
691
+
692
+ // ----------------------------------------------------------------------------
693
+ // The Sampler, which takes logits and returns a sampled token
694
+ // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
695
+
696
+ typedef struct {
697
+ float prob;
698
+ int index;
699
+ } ProbIndex; // struct used when sorting probabilities during top-p sampling
700
+
701
+ typedef struct {
702
+ int vocab_size;
703
+ ProbIndex* probindex; // buffer used in top-p sampling
704
+ float temperature;
705
+ float topp;
706
+ unsigned long long rng_state;
707
+ } Sampler;
708
+
709
+ int sample_argmax(float* probabilities, int n) {
710
+ // return the index that has the highest probability
711
+ int max_i = 0;
712
+ float max_p = probabilities[0];
713
+ for (int i = 1; i < n; i++) {
714
+ if (probabilities[i] > max_p) {
715
+ max_i = i;
716
+ max_p = probabilities[i];
717
+ }
718
+ }
719
+ return max_i;
720
+ }
721
+
722
+ int sample_mult(float* probabilities, int n, float coin) {
723
+ // sample index from probabilities (they must sum to 1!)
724
+ // coin is a random number in [0, 1), usually from random_f32()
725
+ float cdf = 0.0f;
726
+ for (int i = 0; i < n; i++) {
727
+ cdf += probabilities[i];
728
+ if (coin < cdf) {
729
+ return i;
730
+ }
731
+ }
732
+ return n - 1; // in case of rounding errors
733
+ }
734
+
735
+ int compare(const void* a, const void* b) {
736
+ ProbIndex* a_ = (ProbIndex*) a;
737
+ ProbIndex* b_ = (ProbIndex*) b;
738
+ if (a_->prob > b_->prob) return -1;
739
+ if (a_->prob < b_->prob) return 1;
740
+ return 0;
741
+ }
742
+
743
+ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
744
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
745
+ // tokens that exceed probability topp. This way we never sample tokens that
746
+ // have very low probabilities and are less likely to go "off the rails".
747
+ // coin is a random number in [0, 1), usually from random_f32()
748
+
749
+ int n0 = 0;
750
+ // quicksort indices in descending order of probabilities
751
+ // values smaller than (1 - topp) / (n - 1) cannot be part of the result
752
+ // so for efficiency we crop these out as candidates before sorting
753
+ const float cutoff = (1.0f - topp) / (n - 1);
754
+ for (int i = 0; i < n; i++) {
755
+ if (probabilities[i] >= cutoff) {
756
+ probindex[n0].index = i;
757
+ probindex[n0].prob = probabilities[i];
758
+ n0++;
759
+ }
760
+ }
761
+ qsort(probindex, n0, sizeof(ProbIndex), compare);
762
+
763
+ // truncate the list where cumulative probability exceeds topp
764
+ float cumulative_prob = 0.0f;
765
+ int last_idx = n0 - 1; // in case of rounding errors consider all elements
766
+ for (int i = 0; i < n0; i++) {
767
+ cumulative_prob += probindex[i].prob;
768
+ if (cumulative_prob > topp) {
769
+ last_idx = i;
770
+ break; // we've exceeded topp by including last_idx
771
+ }
772
+ }
773
+
774
+ // sample from the truncated list
775
+ float r = coin * cumulative_prob;
776
+ float cdf = 0.0f;
777
+ for (int i = 0; i <= last_idx; i++) {
778
+ cdf += probindex[i].prob;
779
+ if (r < cdf) {
780
+ return probindex[i].index;
781
+ }
782
+ }
783
+ return probindex[last_idx].index; // in case of rounding errors
784
+ }
785
+
786
+ void build_sampler(Sampler* sampler, int vocab_size, float temperature, float topp, unsigned long long rng_seed) {
787
+ sampler->vocab_size = vocab_size;
788
+ sampler->temperature = temperature;
789
+ sampler->topp = topp;
790
+ sampler->rng_state = rng_seed;
791
+ // buffer only used with nucleus sampling; may not need but it's ~small
792
+ sampler->probindex = malloc(sampler->vocab_size * sizeof(ProbIndex));
793
+ }
794
+
795
+ void free_sampler(Sampler* sampler) {
796
+ free(sampler->probindex);
797
+ }
798
+
799
+ unsigned int random_u32(unsigned long long *state) {
800
+ // xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
801
+ *state ^= *state >> 12;
802
+ *state ^= *state << 25;
803
+ *state ^= *state >> 27;
804
+ return (*state * 0x2545F4914F6CDD1Dull) >> 32;
805
+ }
806
+ float random_f32(unsigned long long *state) { // random float32 in [0,1)
807
+ return (random_u32(state) >> 8) / 16777216.0f;
808
+ }
809
+
810
+ int sample(Sampler* sampler, float* logits) {
811
+ // sample the token given the logits and some hyperparameters
812
+ int next;
813
+ if (sampler->temperature == 0.0f) {
814
+ // greedy argmax sampling: take the token with the highest probability
815
+ next = sample_argmax(logits, sampler->vocab_size);
816
+ } else {
817
+ // apply the temperature to the logits
818
+ for (int q=0; q<sampler->vocab_size; q++) { logits[q] /= sampler->temperature; }
819
+ // apply softmax to the logits to get the probabilities for next token
820
+ softmax(logits, sampler->vocab_size);
821
+ // flip a (float) coin (this is our source of entropy for sampling)
822
+ float coin = random_f32(&sampler->rng_state);
823
+ // we sample from this distribution to get the next token
824
+ if (sampler->topp <= 0 || sampler->topp >= 1) {
825
+ // simply sample from the predicted probability distribution
826
+ next = sample_mult(logits, sampler->vocab_size, coin);
827
+ } else {
828
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
829
+ next = sample_topp(logits, sampler->vocab_size, sampler->topp, sampler->probindex, coin);
830
+ }
831
+ }
832
+ return next;
833
+ }
834
+
835
+ // ----------------------------------------------------------------------------
836
+ // utilities: time
837
+
838
+ long time_in_ms() {
839
+ // return time in milliseconds, for benchmarking the model speed
840
+ struct timespec time;
841
+ clock_gettime(CLOCK_REALTIME, &time);
842
+ return time.tv_sec * 1000 + time.tv_nsec / 1000000;
843
+ }
844
+
845
+ // ----------------------------------------------------------------------------
846
+ // generation loop
847
+
848
+ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
849
+ char *empty_prompt = "";
850
+ if (prompt == NULL) { prompt = empty_prompt; }
851
+
852
+ // encode the (string) prompt into tokens sequence
853
+ int num_prompt_tokens = 0;
854
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int)); // +3 for '\0', ?BOS, ?EOS
855
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
856
+ if (num_prompt_tokens < 1) {
857
+ fprintf(stderr, "something is wrong, expected at least 1 prompt token\n");
858
+ exit(EXIT_FAILURE);
859
+ }
860
+
861
+ // start the main loop
862
+ long start = 0; // used to time our code, only initialized after first iteration
863
+ int next; // will store the next token in the sequence
864
+ int token = prompt_tokens[0]; // kick off with the first token in the prompt
865
+ int pos = 0; // position in the sequence
866
+ while (pos < steps) {
867
+
868
+ // forward the transformer to get logits for the next token
869
+ float* logits = forward(transformer, token, pos);
870
+
871
+ // advance the state state machine
872
+ if (pos < num_prompt_tokens - 1) {
873
+ // if we are still processing the input prompt, force the next prompt token
874
+ next = prompt_tokens[pos + 1];
875
+ } else {
876
+ // otherwise sample the next token from the logits
877
+ next = sample(sampler, logits);
878
+ }
879
+ pos++;
880
+
881
+ // data-dependent terminating condition: the BOS (=1) token delimits sequences
882
+ if (next == 1) { break; }
883
+
884
+ // print the token as string, decode it with the Tokenizer object
885
+ char* piece = decode(tokenizer, token, next);
886
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
887
+ fflush(stdout);
888
+ token = next;
889
+
890
+ // init the timer here because the first iteration can be slower
891
+ if (start == 0) { start = time_in_ms(); }
892
+ }
893
+ printf("\n");
894
+
895
+ // report achieved tok/s (pos-1 because the timer starts after first iteration)
896
+ if (pos > 1) {
897
+ long end = time_in_ms();
898
+ fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
899
+ }
900
+
901
+ free(prompt_tokens);
902
+ }
903
+
904
+ void read_stdin(const char* guide, char* buffer, size_t bufsize) {
905
+ // read a line from stdin, up to but not including \n
906
+ printf("%s", guide);
907
+ if (fgets(buffer, bufsize, stdin) != NULL) {
908
+ size_t len = strlen(buffer);
909
+ if (len > 0 && buffer[len - 1] == '\n') {
910
+ buffer[len - 1] = '\0'; // strip newline
911
+ }
912
+ }
913
+ }
914
+
915
+ // ----------------------------------------------------------------------------
916
+ // chat loop
917
+ // I manually inspected the tokens for a few chat conversations compared to
918
+ // python reference and that seemed ok, but this was not thoroughly tested and
919
+ // is not safely implemented, it's more a proof of concept atm.
920
+
921
+ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
922
+ char *cli_user_prompt, char *cli_system_prompt, int steps) {
923
+
924
+ // buffers for reading the system prompt and user prompt from stdin
925
+ // you'll notice they are soomewhat haphazardly and unsafely set atm
926
+ char system_prompt[512];
927
+ char user_prompt[512];
928
+ char rendered_prompt[1152];
929
+ int num_prompt_tokens = 0;
930
+ int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
931
+ int user_idx;
932
+
933
+ // start the main loop
934
+ int8_t user_turn = 1; // user starts
935
+ int next; // will store the next token in the sequence
936
+ int token; // stores the current token to feed into the transformer
937
+ int prev_token;
938
+ int pos = 0; // position in the sequence
939
+ while (pos < steps) {
940
+
941
+ // when it is the user's turn to contribute tokens to the dialog...
942
+ if (user_turn) {
943
+ // get the (optional) system prompt at position 0
944
+ if (pos == 0) {
945
+ // at position 0, the user can also contribute a system prompt
946
+ if (cli_system_prompt == NULL) {
947
+ // system prompt was not passed in, attempt to get it from stdin
948
+ read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
949
+ } else {
950
+ // system prompt was passed in, use it
951
+ strcpy(system_prompt, cli_system_prompt);
952
+ }
953
+ }
954
+ // get the user prompt
955
+ if (pos == 0 && cli_user_prompt != NULL) {
956
+ // user prompt for position 0 was passed in, use it
957
+ strcpy(user_prompt, cli_user_prompt);
958
+ } else {
959
+ // otherwise get user prompt from stdin
960
+ read_stdin("User: ", user_prompt, sizeof(user_prompt));
961
+ }
962
+ // render user/system prompts into the Llama 2 Chat schema
963
+ if (pos == 0 && system_prompt[0] != '\0') {
964
+ char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
965
+ sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
966
+ } else {
967
+ char user_template[] = "[INST] %s [/INST]";
968
+ sprintf(rendered_prompt, user_template, user_prompt);
969
+ }
970
+ // encode the rendered prompt into tokens
971
+ encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
972
+ user_idx = 0; // reset the user index
973
+ user_turn = 0;
974
+ printf("Assistant: ");
975
+ }
976
+
977
+ // determine the token to pass into the transformer next
978
+ if (user_idx < num_prompt_tokens) {
979
+ // if we are still processing the input prompt, force the next prompt token
980
+ token = prompt_tokens[user_idx++];
981
+ } else {
982
+ // otherwise use the next token sampled from previous turn
983
+ token = next;
984
+ }
985
+ // EOS (=2) token ends the Assistant turn
986
+ if (token == 2) { user_turn = 1; }
987
+
988
+ // forward the transformer to get logits for the next token
989
+ float* logits = forward(transformer, token, pos);
990
+ next = sample(sampler, logits);
991
+ pos++;
992
+
993
+ if (user_idx >= num_prompt_tokens && next != 2) {
994
+ // the Assistant is responding, so print its output
995
+ char* piece = decode(tokenizer, token, next);
996
+ safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
997
+ fflush(stdout);
998
+ }
999
+ if (next == 2) { printf("\n"); }
1000
+ }
1001
+ printf("\n");
1002
+ free(prompt_tokens);
1003
+ }
1004
+
1005
+
1006
+ // ----------------------------------------------------------------------------
1007
+ // CLI, include only if not testing
1008
+ #ifndef TESTING
1009
+
1010
+ void error_usage() {
1011
+ fprintf(stderr, "Usage: run <checkpoint> [options]\n");
1012
+ fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
1013
+ fprintf(stderr, "Options:\n");
1014
+ fprintf(stderr, " -t <float> temperature in [0,inf], default 1.0\n");
1015
+ fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling in [0,1] default 0.9\n");
1016
+ fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
1017
+ fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
1018
+ fprintf(stderr, " -i <string> input prompt\n");
1019
+ fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
1020
+ fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
1021
+ fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
1022
+ exit(EXIT_FAILURE);
1023
+ }
1024
+
1025
+ int main(int argc, char *argv[]) {
1026
+
1027
+ // default parameters
1028
+ char *checkpoint_path = NULL; // e.g. out/model.bin
1029
+ char *tokenizer_path = "tokenizer.bin";
1030
+ float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
1031
+ float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
1032
+ int steps = 256; // number of steps to run for
1033
+ char *prompt = NULL; // prompt string
1034
+ unsigned long long rng_seed = 0; // seed rng with time by default
1035
+ char *mode = "generate"; // generate|chat
1036
+ char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
1037
+
1038
+ // poor man's C argparse so we can override the defaults above from the command line
1039
+ if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
1040
+ for (int i = 2; i < argc; i+=2) {
1041
+ // do some basic validation
1042
+ if (i + 1 >= argc) { error_usage(); } // must have arg after flag
1043
+ if (argv[i][0] != '-') { error_usage(); } // must start with dash
1044
+ if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
1045
+ // read in the args
1046
+ if (argv[i][1] == 't') { temperature = atof(argv[i + 1]); }
1047
+ else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
1048
+ else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
1049
+ else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
1050
+ else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
1051
+ else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
1052
+ else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
1053
+ else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
1054
+ else { error_usage(); }
1055
+ }
1056
+
1057
+ // parameter validation/overrides
1058
+ if (rng_seed <= 0) rng_seed = (unsigned int)time(NULL);
1059
+ if (temperature < 0.0) temperature = 0.0;
1060
+ if (topp < 0.0 || 1.0 < topp) topp = 0.9;
1061
+ if (steps < 0) steps = 0;
1062
+
1063
+ // build the Transformer via the model .bin file
1064
+ Transformer transformer;
1065
+ build_transformer(&transformer, checkpoint_path);
1066
+ if (steps == 0 || steps > transformer.config.seq_len) steps = transformer.config.seq_len; // ovrerride to ~max length
1067
+
1068
+ // build the Tokenizer via the tokenizer .bin file
1069
+ Tokenizer tokenizer;
1070
+ build_tokenizer(&tokenizer, tokenizer_path, transformer.config.vocab_size);
1071
+
1072
+ // build the Sampler
1073
+ Sampler sampler;
1074
+ build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
1075
+
1076
+ // run!
1077
+ if (strcmp(mode, "generate") == 0) {
1078
+ generate(&transformer, &tokenizer, &sampler, prompt, steps);
1079
+ } else if (strcmp(mode, "chat") == 0) {
1080
+ chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
1081
+ } else {
1082
+ fprintf(stderr, "unknown mode: %s\n", mode);
1083
+ error_usage();
1084
+ }
1085
+
1086
+ // memory and file handles cleanup
1087
+ free_sampler(&sampler);
1088
+ free_tokenizer(&tokenizer);
1089
+ free_transformer(&transformer);
1090
+ return 0;
1091
+ }
1092
+ #endif
sample.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sample from the trained model with PyTorch
3
+ """
4
+ import os
5
+ import pickle
6
+ from contextlib import nullcontext
7
+ import torch
8
+ from model import ModelArgs, Transformer
9
+ from tokenizer import Tokenizer
10
+
11
+ from tinystories import get_tokenizer_model_path
12
+
13
+ # -----------------------------------------------------------------------------
14
+ checkpoint = 'out/ckpt.pt'
15
+ start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
16
+ num_samples = 1 # number of samples to draw
17
+ max_new_tokens = 100 # number of tokens generated in each sample
18
+ temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
19
+ top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
20
+ tokenizer = "" # override the tokenizer model path
21
+ seed = 1337
22
+ device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
23
+ #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
24
+ dtype = "float32"
25
+ compile = False # use PyTorch 2.0 to compile the model to be faster
26
+ exec(open('configurator.py').read()) # overrides from command line or config file
27
+ # -----------------------------------------------------------------------------
28
+
29
+ torch.manual_seed(seed)
30
+ torch.cuda.manual_seed(seed)
31
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
32
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
33
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
34
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
35
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
36
+
37
+ # init from a model saved in a specific directory
38
+ checkpoint_dict = torch.load(checkpoint, map_location=device)
39
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
40
+ model = Transformer(gptconf)
41
+ state_dict = checkpoint_dict['model']
42
+ unwanted_prefix = '_orig_mod.'
43
+ for k,v in list(state_dict.items()):
44
+ if k.startswith(unwanted_prefix):
45
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46
+ model.load_state_dict(state_dict, strict=False)
47
+
48
+ model.eval()
49
+ model.to(device)
50
+ if compile:
51
+ print("Compiling the model...")
52
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
53
+
54
+ # load the tokenizer
55
+ vocab_source = checkpoint_dict["config"].get("vocab_source", "llama2")
56
+ vocab_size = gptconf.vocab_size
57
+ if tokenizer:
58
+ # a specific tokenizer is provided, use it
59
+ tokenizer_model = tokenizer
60
+ else:
61
+ # let's try to find the tokenizer model automatically. bit gross here...
62
+ query_vocab_size = 0 if vocab_source == "llama2" else vocab_size
63
+ tokenizer_model = get_tokenizer_model_path(vocab_size=query_vocab_size)
64
+ enc = Tokenizer(tokenizer_model=tokenizer_model)
65
+
66
+ # encode the beginning of the prompt
67
+ if start.startswith('FILE:'):
68
+ with open(start[5:], 'r', encoding='utf-8') as f:
69
+ start = f.read()
70
+ start_ids = enc.encode(start, bos=True, eos=False)
71
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
72
+
73
+ # run generation
74
+ with torch.no_grad():
75
+ with ctx:
76
+ for k in range(num_samples):
77
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
78
+ print(enc.decode(y[0].tolist()))
79
+ print('---------------')
test.c ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define TESTING
2
+ #include "run.c"
3
+
4
+ void assert_eq(int a, int b) {
5
+ if (a != b) {
6
+ printf("Assertion failed: %d != %d\n", a, b);
7
+ exit(EXIT_FAILURE);
8
+ }
9
+ }
10
+
11
+ void test_prompt_encoding(Tokenizer* tokenizer, char* prompt, int* expected_tokens, int num_expected_tokens) {
12
+ // encode
13
+ int* prompt_tokens = (int*)malloc((strlen(prompt)+3) * sizeof(int));
14
+ int num_prompt_tokens = 0; // the total number of prompt tokens
15
+ encode(tokenizer, prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
16
+
17
+ #if VERBOSITY == 1
18
+ // print maybe
19
+ printf("expected tokens:\n");
20
+ for (int i = 0; i < num_expected_tokens; i++) printf("%d ", expected_tokens[i]);
21
+ printf("\n");
22
+ printf("actual tokens:\n");
23
+ for (int i = 0; i < num_prompt_tokens; i++) printf("%d ", prompt_tokens[i]);
24
+ printf("\n");
25
+ #endif
26
+
27
+ // verify
28
+ assert_eq(num_prompt_tokens, num_expected_tokens);
29
+ for (int i = 0; i < num_prompt_tokens; i++) {
30
+ assert_eq(prompt_tokens[i], expected_tokens[i]);
31
+ }
32
+
33
+ #if VERBOSITY == 1
34
+ printf("OK\n");
35
+ printf("---\n");
36
+ #endif
37
+ free(prompt_tokens);
38
+ }
39
+
40
+ void test_prompt_encodings() {
41
+ // let's verify that the Tokenizer works as expected
42
+
43
+ char *tokenizer_path = "tokenizer.bin";
44
+ int vocab_size = 32000;
45
+ Tokenizer tokenizer;
46
+ build_tokenizer(&tokenizer, tokenizer_path, vocab_size);
47
+
48
+ // test 0 (test the empty string) (I added this as a simple case)
49
+ char *prompt0 = "";
50
+ int expected_tokens0[] = {1};
51
+ test_prompt_encoding(&tokenizer, prompt0, expected_tokens0, sizeof(expected_tokens0) / sizeof(int));
52
+
53
+ // the tests below are taken from the Meta Llama 2 repo example code
54
+ // https://github.com/facebookresearch/llama/blob/main/example_text_completion.py
55
+ // and the expected tokens come from me breaking in the debugger in Python
56
+
57
+ // test 1
58
+ char *prompt = "I believe the meaning of life is";
59
+ int expected_tokens[] = {1, 306, 4658, 278, 6593, 310, 2834, 338};
60
+ test_prompt_encoding(&tokenizer, prompt, expected_tokens, sizeof(expected_tokens) / sizeof(int));
61
+
62
+ // test 2
63
+ char* prompt2 = "Simply put, the theory of relativity states that ";
64
+ int expected_tokens2[] = {1, 3439, 17632, 1925, 29892, 278, 6368, 310, 14215, 537, 5922, 393, 29871};
65
+ test_prompt_encoding(&tokenizer, prompt2, expected_tokens2, sizeof(expected_tokens2) / sizeof(int));
66
+
67
+ // test 3
68
+ char* prompt3 = "A brief message congratulating the team on the launch:\n\n Hi everyone,\n\n I just ";
69
+ int expected_tokens3[] = {1, 319, 11473, 2643, 378, 629, 271, 18099, 278, 3815, 373, 278, 6826, 29901, 13, 13, 4706, 6324, 14332, 29892, 13, 13, 4706, 306, 925, 29871};
70
+ test_prompt_encoding(&tokenizer, prompt3, expected_tokens3, sizeof(expected_tokens3) / sizeof(int));
71
+
72
+ // test 4
73
+ char* prompt4 = "Translate English to French:\n\n sea otter => loutre de mer\n peppermint => menthe poivrée\n plush girafe => girafe peluche\n cheese =>";
74
+ int expected_tokens4[] = {1, 4103, 9632, 4223, 304, 5176, 29901, 13, 13, 4706, 7205, 4932, 357, 1149, 301, 449, 276, 316, 2778, 13, 4706, 1236, 407, 837, 524, 1149, 6042, 354, 772, 440, 29878, 1318, 13, 4706, 715, 1878, 330, 3055, 1725, 1149, 330, 3055, 1725, 4639, 28754, 13, 4706, 923, 968, 1149};
75
+ test_prompt_encoding(&tokenizer, prompt4, expected_tokens4, sizeof(expected_tokens4) / sizeof(int));
76
+
77
+ // memory and file handles cleanup
78
+ free_tokenizer(&tokenizer);
79
+ }
80
+
81
+ int main(int argc, char *argv[]) {
82
+ test_prompt_encodings();
83
+ printf("ALL OK\n");
84
+ }
test_all.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run simply with
3
+ $ pytest
4
+ """
5
+ import os
6
+ import pytest # pip install pytest
7
+ import requests
8
+ import subprocess
9
+
10
+
11
+ import torch
12
+ from model import ModelArgs, Transformer
13
+ from tokenizer import Tokenizer
14
+
15
+ # -----------------------------------------------------------------------------
16
+ # test utilities
17
+
18
+ test_ckpt_dir = "test"
19
+
20
+ def download_file(url, filename):
21
+ print(f"Downloading {url} to {filename}")
22
+ response = requests.get(url, stream=True)
23
+ response.raise_for_status() # Raise an HTTPError on bad status code
24
+ with open(filename, 'wb') as file:
25
+ for chunk in response.iter_content(chunk_size=8192):
26
+ file.write(chunk)
27
+
28
+ def attempt_download_files():
29
+ os.makedirs(test_ckpt_dir, exist_ok=True)
30
+ root_url = "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K"
31
+ need = ["stories260K.bin", "stories260K.pt", "tok512.bin", "tok512.model"]
32
+ for file in need:
33
+ url = root_url + '/' + file #os.path.join inserts \\ on windows
34
+ filename = os.path.join(test_ckpt_dir, file)
35
+ if not os.path.exists(filename):
36
+ download_file(url, filename)
37
+
38
+ expected_stdout = b'Once upon a time, there was a little girl named Lily. She loved to play outside in the park. One day, she saw a big, red ball. She wanted to play with it, but it was too high.\nLily\'s mom said, "Lily, let\'s go to the park." Lily was sad and didn\'t know what to do. She said, "I want to play with your ball, but I can\'t find it."\nLily was sad and didn\'t know what to do. She said, "I\'m sorry, Lily. I didn\'t know what to do."\nLily didn\'t want to help her mom, so she'
39
+
40
+ # -----------------------------------------------------------------------------
41
+ # actual tests
42
+
43
+ def test_runc():
44
+ """ Forwards a model against a known-good desired outcome in run.c for 200 steps"""
45
+ attempt_download_files()
46
+
47
+ model_path = os.path.join(test_ckpt_dir, "stories260K.bin")
48
+ tokenizer_path = os.path.join(test_ckpt_dir, "tok512.bin")
49
+ command = ["./run", model_path, "-z", tokenizer_path, "-t", "0.0", "-n", "200"]
50
+ with open('err.txt', mode='wb') as fe:
51
+ with open('stdout.txt', mode='wb') as fo:
52
+ proc = subprocess.Popen(command, stdout=fo, stderr=fe) #pipe in windows terminal does funny things like replacing \n with \r\n
53
+ proc.wait()
54
+
55
+ with open('stdout.txt', mode='r') as f:
56
+ stdout = f.read()
57
+ # strip the very last \n that is added by run.c for aesthetic reasons
58
+ stdout = stdout[:-1].encode('ascii')
59
+
60
+ assert stdout == expected_stdout
61
+
62
+ def test_python():
63
+ """ Forwards a model against a known-good desired outcome in sample.py for 200 steps"""
64
+ attempt_download_files()
65
+
66
+ device = "cpu" # stories260K is small enough to just breeze through it on CPU
67
+ checkpoint = os.path.join(test_ckpt_dir, "stories260K.pt")
68
+ checkpoint_dict = torch.load(checkpoint, map_location=device)
69
+ gptconf = ModelArgs(**checkpoint_dict['model_args'])
70
+ model = Transformer(gptconf)
71
+ state_dict = checkpoint_dict['model']
72
+ unwanted_prefix = '_orig_mod.'
73
+ for k,v in list(state_dict.items()):
74
+ if k.startswith(unwanted_prefix):
75
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
76
+ model.load_state_dict(state_dict, strict=False)
77
+ model.eval()
78
+ model.to(device)
79
+ x = torch.tensor([[1]], dtype=torch.long, device=device) # 1 is BOS
80
+ with torch.inference_mode():
81
+ y = model.generate(x, max_new_tokens=200, temperature=0.0)
82
+ pt_tokens = y[0].tolist()
83
+
84
+ tokenizer_model = os.path.join(test_ckpt_dir, "tok512.model")
85
+ enc = Tokenizer(tokenizer_model=tokenizer_model)
86
+ text = enc.decode(pt_tokens)
87
+ text = text.encode('ascii') # turn into bytes
88
+
89
+ assert text == expected_stdout
tinystories.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download, preprocess and serve the TinyStories dataset as a DataLoader.
3
+ """
4
+
5
+ import argparse
6
+ import glob
7
+ import json
8
+ import os
9
+ import random
10
+ from typing import List
11
+ from concurrent.futures import ProcessPoolExecutor
12
+ from functools import partial
13
+
14
+ import numpy as np
15
+ import requests
16
+ import sentencepiece as spm
17
+ import torch
18
+ import torch.distributed as dist
19
+ from tqdm import tqdm
20
+
21
+ from tokenizer import Tokenizer
22
+
23
+ DATA_CACHE_DIR = "data"
24
+
25
+ def download_file(url: str, fname: str, chunk_size=1024):
26
+ """Helper function to download a file from a given url"""
27
+ resp = requests.get(url, stream=True)
28
+ total = int(resp.headers.get("content-length", 0))
29
+ with open(fname, "wb") as file, tqdm(
30
+ desc=fname,
31
+ total=total,
32
+ unit="iB",
33
+ unit_scale=True,
34
+ unit_divisor=1024,
35
+ ) as bar:
36
+ for data in resp.iter_content(chunk_size=chunk_size):
37
+ size = file.write(data)
38
+ bar.update(size)
39
+
40
+
41
+ def download():
42
+ """Downloads the TinyStories dataset to DATA_CACHE_DIR"""
43
+ os.makedirs(DATA_CACHE_DIR, exist_ok=True)
44
+
45
+ # download the TinyStories dataset, unless it's already downloaded
46
+ data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
47
+ data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
48
+ if not os.path.exists(data_filename):
49
+ print(f"Downloading {data_url} to {data_filename}...")
50
+ download_file(data_url, data_filename)
51
+ else:
52
+ print(f"{data_filename} already exists, skipping download...")
53
+
54
+ # unpack the tar.gz file into all the data shards (json files)
55
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
56
+ if not os.path.exists(data_dir):
57
+ os.makedirs(data_dir, exist_ok=True)
58
+ print(f"Unpacking {data_filename}...")
59
+ os.system(f"tar -xzf {data_filename} -C {data_dir}")
60
+ else:
61
+ print(f"{data_dir} already exists, skipping unpacking...")
62
+
63
+ # print a single example just for debugging and such
64
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
65
+ with open(shard_filenames[0], "r") as f:
66
+ data = json.load(f)
67
+ print("Download done.")
68
+ print(f"Number of shards: {len(shard_filenames)}")
69
+ print(f"Example story:\n{data[0]}")
70
+
71
+ def train_vocab(vocab_size):
72
+ """
73
+ Trains a custom sentencepiece tokenizer on the TinyStories dataset.
74
+ The custom tokenizer files will be saved in DATA_CACHE_DIR/tok{N} directories,
75
+ where N is the vocab size. This is also where the pretok .bin files will go.
76
+ """
77
+ assert vocab_size > 0, "Vocab size must be positive"
78
+
79
+ # output file prefix path for sentencepiece
80
+ prefix = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
81
+
82
+ # how many shards we'll use for vocab training, kept low for efficiency
83
+ num_shards = 10
84
+
85
+ # 1) export a large chunk of text as a single text file tiny.txt
86
+ tiny_file = os.path.join(DATA_CACHE_DIR, "tiny.txt")
87
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
88
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
89
+
90
+ print(f"Writing temporary file {tiny_file} with {num_shards} shards...")
91
+ with open(tiny_file, "w", encoding="utf-8") as of:
92
+ for shard in tqdm(shard_filenames[:num_shards]):
93
+ with open(shard, "r") as f:
94
+ data = json.load(f)
95
+ for example in data:
96
+ text = example["story"]
97
+ text = text.strip()
98
+ of.write(text + "\n")
99
+ print(f"Size is: {os.path.getsize(tiny_file) / 1024 / 1024:.2f} MB")
100
+
101
+ # 2) train the sentencepiece model
102
+ print("Will now train the vocab...")
103
+ spm.SentencePieceTrainer.train(input=tiny_file,
104
+ model_prefix=prefix,
105
+ model_type="bpe",
106
+ vocab_size=vocab_size,
107
+ self_test_sample_size=0,
108
+ input_format="text",
109
+ character_coverage=1.0,
110
+ num_threads=os.cpu_count(),
111
+ split_digits=True,
112
+ allow_whitespace_only_pieces=True,
113
+ byte_fallback=True,
114
+ unk_surface=r" \342\201\207 ",
115
+ normalization_rule_name="identity")
116
+
117
+ # 3) optional cleanup, ask the user if they'd like to delete tiny.txt
118
+ dec = input(f"Delete the temporary file {tiny_file}? [y/N] ")
119
+ if dec.lower() == "y":
120
+ os.remove(tiny_file)
121
+ print(f"Deleted {tiny_file}")
122
+
123
+ print(f"Trained tokenizer is in {prefix}.model")
124
+ print("Done.")
125
+
126
+
127
+ def process_shard(args, vocab_size):
128
+ shard_id, shard = args
129
+ tokenizer_model = get_tokenizer_model_path(vocab_size)
130
+ enc = Tokenizer(tokenizer_model)
131
+ with open(shard, "r") as f:
132
+ data = json.load(f)
133
+ all_tokens = []
134
+ for example in tqdm(data, position=shard_id):
135
+ text = example["story"]
136
+ text = text.strip() # get rid of leading/trailing whitespace
137
+ tokens = enc.encode(text, bos=True, eos=False) # encode the text, use BOS
138
+ all_tokens.extend(tokens)
139
+ # convert to uint16 nparray
140
+ all_tokens = np.array(all_tokens, dtype=np.uint16)
141
+ # calculate the output filename
142
+ if vocab_size == 0:
143
+ # if we're using Llama 2, just save the tokenized file in the same dir
144
+ tokenized_filename = shard.replace(".json", ".bin")
145
+ else:
146
+ # save .bin files into a new tok{N} directory
147
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
148
+ shard_basename = os.path.basename(shard)
149
+ bin_basename = shard_basename.replace(".json", ".bin")
150
+ tokenized_filename = os.path.join(bin_dir, bin_basename)
151
+ # write the bytes
152
+ with open(tokenized_filename, "wb") as f:
153
+ f.write(all_tokens.tobytes())
154
+ # calculate the average sequence length (they are separated by BOS=1)
155
+ avg_seq_len = all_tokens.size / ((all_tokens == 1).sum())
156
+ print(f"Saved {tokenized_filename}, average seqlen: {avg_seq_len:.2f}")
157
+
158
+
159
+ def pretokenize(vocab_size):
160
+ # iterate the shards and tokenize all of them one by one
161
+ data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
162
+ shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
163
+ if vocab_size > 0:
164
+ # .bin files will be saved into tok{N} directory, create it once here
165
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}")
166
+ os.makedirs(bin_dir, exist_ok=True)
167
+
168
+ # process all the shards in a process pool
169
+ fun = partial(process_shard, vocab_size=vocab_size)
170
+ with ProcessPoolExecutor() as executor:
171
+ executor.map(fun, enumerate(shard_filenames))
172
+ print("Done.")
173
+
174
+
175
+ class PretokDataset(torch.utils.data.IterableDataset):
176
+ """Loads pretokenized examples from disk and yields them as PyTorch tensors."""
177
+
178
+ def __init__(self, split, max_seq_len, vocab_size, vocab_source):
179
+ super().__init__()
180
+ self.split = split
181
+ self.max_seq_len = max_seq_len
182
+ self.vocab_size = vocab_size
183
+ self.vocab_source = vocab_source
184
+
185
+ def __iter__(self):
186
+ # get worker info within a DataLoader
187
+ worker_info = torch.utils.data.get_worker_info()
188
+ worker_id = worker_info.id if worker_info else 0
189
+ # get DDP rank info
190
+ rank = dist.get_rank() if dist.is_initialized() else 0
191
+ # combine the worker_id and worker_rank to create a unique seed for rng
192
+ seed = 42 + worker_id + 1337 * rank
193
+ rng = random.Random(seed)
194
+ print(f"Created a PretokDataset with rng seed {seed}")
195
+ if self.vocab_source == "llama2":
196
+ # the .bin files are right along the .json files
197
+ bin_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
198
+ shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
199
+ elif self.vocab_source == "custom":
200
+ # the .bin files are in tok{N} directory
201
+ bin_dir = os.path.join(DATA_CACHE_DIR, f"tok{self.vocab_size}")
202
+ shard_filenames = sorted(glob.glob(os.path.join(bin_dir, "*.bin")))
203
+ # train/test split. let's use only shard 0 for test split, rest train
204
+ shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
205
+ assert len(shard_filenames)>0, f"No bin files found in {bin_dir}"
206
+ while True:
207
+ rng.shuffle(shard_filenames)
208
+ for shard in shard_filenames:
209
+ # open the dataset for reading but keep it on disk with memmap
210
+ m = np.memmap(shard, dtype=np.uint16, mode="r")
211
+ num_batches = len(m) // self.max_seq_len
212
+ num_batches -= 1 # drop the last partial batch
213
+ assert num_batches > 0, "this shard is way too small? investigate."
214
+ ixs = list(range(num_batches))
215
+ rng.shuffle(ixs)
216
+ for ix in ixs:
217
+ start = ix * self.max_seq_len
218
+ end = start + self.max_seq_len + 1
219
+ # calling .astype will copy the data into a new numpy array, now in RAM
220
+ chunk = torch.from_numpy((m[start:end]).astype(np.int64))
221
+ x = chunk[:-1]
222
+ y = chunk[1:]
223
+ yield x, y
224
+
225
+ # -----------------------------------------------------------------------------
226
+ # public interface functions
227
+
228
+ def get_tokenizer_model_path(vocab_size):
229
+ """
230
+ Returns path to the sentencepiece tokenizer model for a given vocab size
231
+ vocab_size = 0 designates the default Llama 2 tokenizer, in that case
232
+ None is returned.
233
+ """
234
+ if vocab_size == 0:
235
+ return None
236
+ else:
237
+ return os.path.join(DATA_CACHE_DIR, f"tok{vocab_size}.model")
238
+
239
+ class Task:
240
+
241
+ @staticmethod
242
+ def iter_batches(batch_size, device, num_workers=0, **dataset_kwargs):
243
+ ds = PretokDataset(**dataset_kwargs)
244
+ dl = torch.utils.data.DataLoader(
245
+ ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
246
+ )
247
+ for x, y in dl:
248
+ x = x.to(device, non_blocking=True)
249
+ y = y.to(device, non_blocking=True)
250
+ yield x, y
251
+
252
+ # -----------------------------------------------------------------------------
253
+ # CLI for constructing the dataset
254
+
255
+ if __name__ == "__main__":
256
+ """
257
+ These stages are designed to be run in order.
258
+
259
+ To tokenize data with the Llama 2 tokenizer:
260
+ python tinystories.py download
261
+ python tinystories.py pretokenize
262
+
263
+ To tokenize data with a custom tokenizer we train ourselves with sentencepiece, e.g.:
264
+ python tinystories.py download
265
+ python tinystories.py train_vocab --vocab_size=2048
266
+ python tinystories.py pretokenize --vocab_size=2048
267
+ """
268
+ parser = argparse.ArgumentParser()
269
+ parser.add_argument("stage", type=str, choices=["download", "pretokenize", "train_vocab"])
270
+ parser.add_argument("--vocab_size", type=int, default=0, help="pretokenization vocab size. 0 = use Llama 2 tokenizer.")
271
+ args = parser.parse_args()
272
+
273
+ # depending on the stage call the appropriate function
274
+ if args.stage == "download":
275
+ download()
276
+ elif args.stage == "train_vocab":
277
+ train_vocab(vocab_size=args.vocab_size)
278
+ elif args.stage == "pretokenize":
279
+ pretokenize(vocab_size=args.vocab_size)
280
+ else:
281
+ raise ValueError(f"Unknown stage {args.stage}")
tokenizer.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50a52ef822ee9e83de5ce9d0be0a025a773d019437f58b5ff9dcafb063ece361
3
+ size 433869
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Taken from llama code and lightly modified
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
4
+
5
+ import os
6
+ import struct
7
+ import argparse
8
+ from typing import List
9
+
10
+ from sentencepiece import SentencePieceProcessor
11
+
12
+ TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model
13
+
14
+ class Tokenizer:
15
+ def __init__(self, tokenizer_model=None):
16
+ model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
17
+ assert os.path.isfile(model_path), model_path
18
+ self.sp_model = SentencePieceProcessor(model_file=model_path)
19
+ self.model_path = model_path
20
+
21
+ # BOS / EOS token IDs
22
+ self.n_words: int = self.sp_model.vocab_size()
23
+ self.bos_id: int = self.sp_model.bos_id()
24
+ self.eos_id: int = self.sp_model.eos_id()
25
+ self.pad_id: int = self.sp_model.pad_id()
26
+ #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}")
27
+ assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
28
+
29
+ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
30
+ assert type(s) is str
31
+ t = self.sp_model.encode(s)
32
+ if bos:
33
+ t = [self.bos_id] + t
34
+ if eos:
35
+ t = t + [self.eos_id]
36
+ return t
37
+
38
+ def decode(self, t: List[int]) -> str:
39
+ return self.sp_model.decode(t)
40
+
41
+ def export(self):
42
+
43
+ # get all the tokens (postprocessed) and their scores as floats
44
+ tokens, scores = [], []
45
+ for i in range(self.n_words):
46
+
47
+ # decode the token and light postprocessing
48
+ t = self.sp_model.id_to_piece(i)
49
+ s = self.sp_model.get_score(i)
50
+ if i == self.bos_id:
51
+ t = '\n<s>\n'
52
+ elif i == self.eos_id:
53
+ t = '\n</s>\n'
54
+ t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
55
+ b = t.encode('utf-8') # bytes of this token, utf-8 encoded
56
+
57
+ tokens.append(b)
58
+ scores.append(s)
59
+
60
+ # record the max token length
61
+ max_token_length = max(len(t) for t in tokens)
62
+
63
+ # write to a binary file
64
+ # the tokenizer.bin file is the same as .model file, but .bin
65
+ tokenizer_bin = self.model_path.replace('.model', '.bin')
66
+ with open(tokenizer_bin, 'wb') as f:
67
+ f.write(struct.pack("I", max_token_length))
68
+ for bytes, score in zip(tokens, scores):
69
+ f.write(struct.pack("fI", score, len(bytes)))
70
+ f.write(bytes)
71
+
72
+ if __name__ == "__main__":
73
+ parser = argparse.ArgumentParser()
74
+ parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ")
75
+ args = parser.parse_args()
76
+
77
+ t = Tokenizer(args.tokenizer_model)
78
+ t.export()
train.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This training script can be run both on a single gpu in debug mode,
3
+ and also in a larger training run with distributed data parallel (ddp).
4
+
5
+ To run on a single GPU small debug run, example:
6
+ $ python -m train.py --compile=False --eval_iters=10 --batch_size=8
7
+
8
+ To run with DDP on 4 gpus on 1 node, example:
9
+ $ torchrun --standalone --nproc_per_node=4 train.py
10
+
11
+ To run with DDP on 4 gpus across 2 nodes, example:
12
+ - Run on the first (master) node with example IP 123.456.123.456:
13
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14
+ - Run on the worker node:
15
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16
+ (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
17
+ """
18
+
19
+ import math
20
+ import os
21
+ import time
22
+ from contextlib import nullcontext
23
+ from datetime import datetime
24
+ from functools import partial
25
+
26
+ import torch
27
+ from model import Transformer, ModelArgs
28
+ from torch.distributed import destroy_process_group, init_process_group
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+
31
+ from tinystories import Task
32
+ from export import model_export
33
+
34
+ # -----------------------------------------------------------------------------
35
+ # I/O
36
+ out_dir = "out"
37
+ eval_interval = 2000
38
+ log_interval = 1
39
+ eval_iters = 100
40
+ eval_only = False # if True, script exits right after the first eval
41
+ always_save_checkpoint = False # if True, always save a checkpoint after each eval
42
+ init_from = "scratch" # 'scratch' or 'resume'
43
+ # wandb logging
44
+ wandb_log = False # disabled by default
45
+ wandb_project = "llamac"
46
+ wandb_run_name = "run" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
47
+ # data
48
+ batch_size = 128 # if gradient_accumulation_steps > 1, this is the micro-batch size
49
+ max_seq_len = 256
50
+ vocab_source = "llama2" # llama2|custom; use Lllama 2 vocab from Meta, or custom trained
51
+ vocab_size = 32000 # the Llama 2 tokenizer has 32K tokens
52
+ # model
53
+ dim = 288
54
+ n_layers = 6
55
+ n_heads = 6
56
+ n_kv_heads = 6
57
+ multiple_of = 32
58
+ dropout = 0.0
59
+ # adamw optimizer
60
+ gradient_accumulation_steps = 4 # used to simulate larger batch sizes
61
+ learning_rate = 5e-4 # max learning rate
62
+ max_iters = 100000 # total number of training iterations
63
+ weight_decay = 1e-1
64
+ beta1 = 0.9
65
+ beta2 = 0.95
66
+ grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
67
+ # learning rate decay settings
68
+ decay_lr = True # whether to decay the learning rate
69
+ warmup_iters = 1000 # how many steps to warm up for
70
+ # system
71
+ device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
72
+ dtype = "bfloat16" # float32|bfloat16|float16
73
+ compile = True # use PyTorch 2.0 to compile the model to be faster
74
+ # -----------------------------------------------------------------------------
75
+ config_keys = [
76
+ k
77
+ for k, v in globals().items()
78
+ if not k.startswith("_") and isinstance(v, (int, float, bool, str))
79
+ ]
80
+ exec(open("configurator.py").read()) # overrides from command line or config file
81
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
82
+ # -----------------------------------------------------------------------------
83
+
84
+ # fixing some hyperparams to sensible defaults
85
+ lr_decay_iters = max_iters # should be ~= max_iters per Chinchilla
86
+ min_lr = 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
87
+
88
+ # validating checks
89
+ assert vocab_source in ["llama2", "custom"]
90
+ assert vocab_source == "custom" or vocab_size == 32000, "The vocab from Meta has 32K tokens"
91
+
92
+ # various inits, derived attributes, I/O setup
93
+ ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
94
+ if ddp:
95
+ init_process_group(backend="nccl")
96
+ ddp_rank = int(os.environ["RANK"])
97
+ ddp_local_rank = int(os.environ["LOCAL_RANK"])
98
+ ddp_world_size = int(os.environ["WORLD_SIZE"])
99
+ device = f"cuda:{ddp_local_rank}"
100
+ torch.cuda.set_device(device)
101
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
102
+ seed_offset = ddp_rank # each process gets a different seed
103
+ # world_size number of processes will be training simultaneously, so we can scale
104
+ # down the desired gradient accumulation iterations per process proportionally
105
+ assert gradient_accumulation_steps % ddp_world_size == 0
106
+ gradient_accumulation_steps //= ddp_world_size
107
+ else:
108
+ # if not ddp, we are running on a single gpu, and one process
109
+ master_process = True
110
+ seed_offset = 0
111
+ ddp_world_size = 1
112
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
113
+ if master_process:
114
+ print(f"tokens per iteration will be: {tokens_per_iter:,}")
115
+ print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
116
+
117
+ if master_process:
118
+ os.makedirs(out_dir, exist_ok=True)
119
+ torch.manual_seed(1337 + seed_offset)
120
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
121
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
122
+ device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
123
+ # note: float16 data type will automatically use a GradScaler
124
+ ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
125
+ ctx = (
126
+ nullcontext()
127
+ if device_type == "cpu"
128
+ else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
129
+ )
130
+
131
+ # task-specific setup
132
+ iter_batches = partial(
133
+ Task.iter_batches,
134
+ batch_size=batch_size,
135
+ max_seq_len=max_seq_len,
136
+ vocab_size=vocab_size,
137
+ vocab_source=vocab_source,
138
+ device=device,
139
+ num_workers=0,
140
+ )
141
+
142
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
143
+ iter_num = 0
144
+ best_val_loss = 1e9
145
+
146
+ # model init
147
+ model_args = dict(
148
+ dim=dim,
149
+ n_layers=n_layers,
150
+ n_heads=n_heads,
151
+ n_kv_heads=n_kv_heads,
152
+ vocab_size=vocab_size,
153
+ multiple_of=multiple_of,
154
+ max_seq_len=max_seq_len,
155
+ dropout=dropout,
156
+ ) # start with model_args from command line
157
+ if init_from == "scratch":
158
+ # init a new model from scratch
159
+ print("Initializing a new model from scratch")
160
+ gptconf = ModelArgs(**model_args)
161
+ model = Transformer(gptconf)
162
+ elif init_from == "resume":
163
+ print(f"Resuming training from {out_dir}")
164
+ # resume training from a checkpoint.
165
+ ckpt_path = os.path.join(out_dir, "ckpt.pt")
166
+ checkpoint = torch.load(ckpt_path, map_location=device)
167
+ checkpoint_model_args = checkpoint["model_args"]
168
+ # force these config attributes to be equal otherwise we can't even resume training
169
+ # the rest of the attributes (e.g. dropout) can stay as desired from command line
170
+ for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
171
+ model_args[k] = checkpoint_model_args[k]
172
+ # create the model
173
+ gptconf = ModelArgs(**model_args)
174
+ model = Transformer(gptconf)
175
+ state_dict = checkpoint["model"]
176
+ # fix the keys of the state dictionary :(
177
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
178
+ unwanted_prefix = "_orig_mod."
179
+ for k, v in list(state_dict.items()):
180
+ if k.startswith(unwanted_prefix):
181
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
182
+ model.load_state_dict(state_dict)
183
+ iter_num = checkpoint["iter_num"]
184
+ best_val_loss = checkpoint["best_val_loss"]
185
+ model.to(device)
186
+
187
+ # initialize a GradScaler. If enabled=False scaler is a no-op
188
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
189
+
190
+ # optimizer
191
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
192
+ if init_from == "resume" and "optimizer" in checkpoint:
193
+ optimizer.load_state_dict(checkpoint["optimizer"])
194
+ checkpoint = None # free up memory
195
+
196
+ # compile the model
197
+ if compile:
198
+ print("compiling the model... (takes a ~minute)")
199
+ unoptimized_model = model
200
+ model = torch.compile(model) # requires PyTorch 2.0
201
+
202
+ # wrap model into DDP container
203
+ if ddp:
204
+ # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
205
+ # construction time since NCCL does not support `ComplexFloat`
206
+ prefix = "_orig_mod." if compile else ""
207
+ model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
208
+ model = DDP(model, device_ids=[ddp_local_rank])
209
+
210
+ # helps estimate an arbitrarily accurate loss over either split using many batches
211
+ @torch.no_grad()
212
+ def estimate_loss():
213
+ out = {}
214
+ model.eval()
215
+ for split in ["train", "val"]:
216
+ batch_iter = iter_batches(split=split)
217
+ losses = torch.zeros(eval_iters) # keep on CPU
218
+ for k in range(eval_iters):
219
+ X, Y = next(batch_iter)
220
+ with ctx:
221
+ logits = model(X, Y)
222
+ loss = raw_model.last_loss
223
+ losses[k] = loss.item()
224
+ out[split] = losses.mean()
225
+ model.train()
226
+ return out
227
+
228
+ # learning rate decay scheduler (cosine with warmup)
229
+ def get_lr(it):
230
+ # 1) linear warmup for warmup_iters steps
231
+ if it < warmup_iters:
232
+ return learning_rate * it / warmup_iters
233
+ # 2) if it > lr_decay_iters, return min learning rate
234
+ if it > lr_decay_iters:
235
+ return min_lr
236
+ # 3) in between, use cosine decay down to min learning rate
237
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
238
+ assert 0 <= decay_ratio <= 1
239
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
240
+ return min_lr + coeff * (learning_rate - min_lr)
241
+
242
+ # logging
243
+ if wandb_log and master_process:
244
+ import wandb
245
+ wandb.init(project=wandb_project, name=wandb_run_name, config=config)
246
+
247
+ # training loop
248
+ train_batch_iter = iter_batches(split="train")
249
+ X, Y = next(train_batch_iter) # fetch the very first batch
250
+ t0 = time.time()
251
+ local_iter_num = 0 # number of iterations in the lifetime of this process
252
+ raw_model = model.module if ddp else model # unwrap DDP container if needed
253
+ running_mfu = -1.0
254
+ while True:
255
+ # determine and set the learning rate for this iteration
256
+ lr = get_lr(iter_num) if decay_lr else learning_rate
257
+ for param_group in optimizer.param_groups:
258
+ param_group["lr"] = lr
259
+
260
+ # evaluate the loss on train/val sets and write checkpoints
261
+ if iter_num % eval_interval == 0 and master_process:
262
+ losses = estimate_loss()
263
+ print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
264
+ if wandb_log:
265
+ try:
266
+ wandb.log(
267
+ {
268
+ "iter": iter_num,
269
+ "tokens": iter_num * tokens_per_iter,
270
+ "loss/train": losses["train"],
271
+ "loss/val": losses["val"],
272
+ "lr": lr,
273
+ "mfu": running_mfu * 100, # convert to percentage
274
+ }, step = iter_num
275
+ )
276
+ except Exception as e:
277
+ print(f"logging to wandb failed: {e}")
278
+ if losses["val"] < best_val_loss or always_save_checkpoint:
279
+ best_val_loss = losses["val"]
280
+ if iter_num > 0:
281
+ checkpoint = {
282
+ "model": raw_model.state_dict(),
283
+ "optimizer": optimizer.state_dict(),
284
+ "model_args": model_args,
285
+ "iter_num": iter_num,
286
+ "best_val_loss": best_val_loss,
287
+ "config": config,
288
+ }
289
+ print(f"saving checkpoint to {out_dir}")
290
+ torch.save(checkpoint, os.path.join(out_dir, "ckpt.pt"))
291
+ model_export(raw_model, os.path.join(out_dir, "model.bin"), version=0)
292
+ if iter_num == 0 and eval_only:
293
+ break
294
+
295
+ # forward backward update, with optional gradient accumulation to simulate larger batch size
296
+ # and using the GradScaler if data type is float16
297
+ for micro_step in range(gradient_accumulation_steps):
298
+ if ddp:
299
+ # in DDP training we only need to sync gradients at the last micro step.
300
+ # the official way to do this is with model.no_sync() context manager, but
301
+ # I really dislike that this bloats the code and forces us to repeat code
302
+ # looking at the source of that context manager, it just toggles this variable
303
+ model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
304
+ with ctx:
305
+ logits = model(X, Y)
306
+ loss = raw_model.last_loss
307
+ loss = loss / gradient_accumulation_steps
308
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
309
+ X, Y = next(train_batch_iter)
310
+ # backward pass, with gradient scaling if training in fp16
311
+ scaler.scale(loss).backward()
312
+ # clip the gradient
313
+ if grad_clip != 0.0:
314
+ scaler.unscale_(optimizer)
315
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
316
+ # step the optimizer and scaler if training in fp16
317
+ scaler.step(optimizer)
318
+ scaler.update()
319
+ # flush the gradients as soon as we can, no need for this memory anymore
320
+ optimizer.zero_grad(set_to_none=True)
321
+
322
+ # timing and logging
323
+ t1 = time.time()
324
+ dt = t1 - t0
325
+ t0 = t1
326
+ if iter_num % log_interval == 0 and master_process:
327
+ # get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
328
+ lossf = loss.item() * gradient_accumulation_steps
329
+ if local_iter_num >= 5: # let the training loop settle a bit
330
+ mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
331
+ running_mfu = mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
332
+ print(
333
+ f"{iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
334
+ )
335
+ iter_num += 1
336
+ local_iter_num += 1
337
+
338
+ # termination conditions
339
+ if iter_num > max_iters:
340
+ break
341
+
342
+ if ddp:
343
+ destroy_process_group()
win.c ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "win.h"
2
+ #include <errno.h>
3
+ #include <io.h>
4
+
5
+ #ifndef FILE_MAP_EXECUTE
6
+ #define FILE_MAP_EXECUTE 0x0020
7
+ #endif /* FILE_MAP_EXECUTE */
8
+
9
+ static int __map_mman_error(const uint32_t err, const int deferr)
10
+ {
11
+ if (err == 0)
12
+ return 0;
13
+ //TODO: implement
14
+ return err;
15
+ }
16
+
17
+ static uint32_t __map_mmap_prot_page(const int prot)
18
+ {
19
+ uint32_t protect = 0;
20
+
21
+ if (prot == PROT_NONE)
22
+ return protect;
23
+
24
+ if ((prot & PROT_EXEC) != 0)
25
+ {
26
+ protect = ((prot & PROT_WRITE) != 0) ?
27
+ PAGE_EXECUTE_READWRITE : PAGE_EXECUTE_READ;
28
+ }
29
+ else
30
+ {
31
+ protect = ((prot & PROT_WRITE) != 0) ?
32
+ PAGE_READWRITE : PAGE_READONLY;
33
+ }
34
+
35
+ return protect;
36
+ }
37
+
38
+ static uint32_t __map_mmap_prot_file(const int prot)
39
+ {
40
+ uint32_t desiredAccess = 0;
41
+
42
+ if (prot == PROT_NONE)
43
+ return desiredAccess;
44
+
45
+ if ((prot & PROT_READ) != 0)
46
+ desiredAccess |= FILE_MAP_READ;
47
+ if ((prot & PROT_WRITE) != 0)
48
+ desiredAccess |= FILE_MAP_WRITE;
49
+ if ((prot & PROT_EXEC) != 0)
50
+ desiredAccess |= FILE_MAP_EXECUTE;
51
+
52
+ return desiredAccess;
53
+ }
54
+
55
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off)
56
+ {
57
+ HANDLE fm, h;
58
+ void * map = MAP_FAILED;
59
+
60
+ #ifdef _MSC_VER
61
+ #pragma warning(push)
62
+ #pragma warning(disable: 4293)
63
+ #endif
64
+
65
+ const uint32_t dwFileOffsetLow = (uint32_t)(off & 0xFFFFFFFFL);
66
+ const uint32_t dwFileOffsetHigh = (uint32_t)((off >> 32) & 0xFFFFFFFFL);
67
+ const uint32_t protect = __map_mmap_prot_page(prot);
68
+ const uint32_t desiredAccess = __map_mmap_prot_file(prot);
69
+
70
+ const ssize_t maxSize = off + (ssize_t)len;
71
+
72
+ const uint32_t dwMaxSizeLow = (uint32_t)(maxSize & 0xFFFFFFFFL);
73
+ const uint32_t dwMaxSizeHigh = (uint32_t)((maxSize >> 32) & 0xFFFFFFFFL);
74
+
75
+ #ifdef _MSC_VER
76
+ #pragma warning(pop)
77
+ #endif
78
+
79
+ errno = 0;
80
+
81
+ if (len == 0
82
+ /* Unsupported flag combinations */
83
+ || (flags & MAP_FIXED) != 0
84
+ /* Usupported protection combinations */
85
+ || prot == PROT_EXEC)
86
+ {
87
+ errno = EINVAL;
88
+ return MAP_FAILED;
89
+ }
90
+
91
+ h = ((flags & MAP_ANONYMOUS) == 0) ?
92
+ (HANDLE)_get_osfhandle(fildes) : INVALID_HANDLE_VALUE;
93
+
94
+ if ((flags & MAP_ANONYMOUS) == 0 && h == INVALID_HANDLE_VALUE)
95
+ {
96
+ errno = EBADF;
97
+ return MAP_FAILED;
98
+ }
99
+
100
+ fm = CreateFileMapping(h, NULL, protect, dwMaxSizeHigh, dwMaxSizeLow, NULL);
101
+
102
+ if (fm == NULL)
103
+ {
104
+ errno = __map_mman_error(GetLastError(), EPERM);
105
+ return MAP_FAILED;
106
+ }
107
+
108
+ map = MapViewOfFile(fm, desiredAccess, dwFileOffsetHigh, dwFileOffsetLow, len);
109
+
110
+ CloseHandle(fm);
111
+
112
+ if (map == NULL)
113
+ {
114
+ errno = __map_mman_error(GetLastError(), EPERM);
115
+ return MAP_FAILED;
116
+ }
117
+
118
+ return map;
119
+ }
120
+
121
+ int munmap(void *addr, size_t len)
122
+ {
123
+ if (UnmapViewOfFile(addr))
124
+ return 0;
125
+
126
+ errno = __map_mman_error(GetLastError(), EPERM);
127
+
128
+ return -1;
129
+ }
130
+
131
+ int mprotect(void *addr, size_t len, int prot)
132
+ {
133
+ uint32_t newProtect = __map_mmap_prot_page(prot);
134
+ uint32_t oldProtect = 0;
135
+
136
+ if (VirtualProtect(addr, len, newProtect, &oldProtect))
137
+ return 0;
138
+
139
+ errno = __map_mman_error(GetLastError(), EPERM);
140
+
141
+ return -1;
142
+ }
143
+
144
+ int msync(void *addr, size_t len, int flags)
145
+ {
146
+ if (FlushViewOfFile(addr, len))
147
+ return 0;
148
+
149
+ errno = __map_mman_error(GetLastError(), EPERM);
150
+
151
+ return -1;
152
+ }
153
+
154
+ int mlock(const void *addr, size_t len)
155
+ {
156
+ if (VirtualLock((LPVOID)addr, len))
157
+ return 0;
158
+
159
+ errno = __map_mman_error(GetLastError(), EPERM);
160
+
161
+ return -1;
162
+ }
163
+
164
+ int munlock(const void *addr, size_t len)
165
+ {
166
+ if (VirtualUnlock((LPVOID)addr, len))
167
+ return 0;
168
+
169
+ errno = __map_mman_error(GetLastError(), EPERM);
170
+
171
+ return -1;
172
+ }
173
+
174
+ // Portable clock_gettime function for Windows
175
+ int clock_gettime(int clk_id, struct timespec *tp) {
176
+ uint32_t ticks = GetTickCount();
177
+ tp->tv_sec = ticks / 1000;
178
+ tp->tv_nsec = (ticks % 1000) * 1000000;
179
+ return 0;
180
+ }
win.h ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef _WIN_H_
2
+ #define _WIN_H_
3
+
4
+ #define WIN32_LEAN_AND_MEAN // Exclude rarely-used stuff from Windows headers
5
+ #include <windows.h>
6
+ #include <time.h>
7
+ #include <stdint.h>
8
+
9
+ #define ssize_t int64_t
10
+ #define ftell _ftelli64
11
+
12
+ // Below code is originally from mman-win32
13
+ //
14
+ /*
15
+ * sys/mman.h
16
+ * mman-win32
17
+ */
18
+
19
+ #ifndef _WIN32_WINNT // Allow use of features specific to Windows XP or later.
20
+ #define _WIN32_WINNT 0x0501 // Change this to the appropriate value to target other versions of Windows.
21
+ #endif
22
+
23
+ /* All the headers include this file. */
24
+ #ifndef _MSC_VER
25
+ #include <_mingw.h>
26
+ #endif
27
+
28
+ #include <sys/types.h>
29
+
30
+ #ifdef __cplusplus
31
+ extern "C" {
32
+ #endif
33
+
34
+ #define PROT_NONE 0
35
+ #define PROT_READ 1
36
+ #define PROT_WRITE 2
37
+ #define PROT_EXEC 4
38
+
39
+ #define MAP_FILE 0
40
+ #define MAP_SHARED 1
41
+ #define MAP_PRIVATE 2
42
+ #define MAP_TYPE 0xf
43
+ #define MAP_FIXED 0x10
44
+ #define MAP_ANONYMOUS 0x20
45
+ #define MAP_ANON MAP_ANONYMOUS
46
+
47
+ #define MAP_FAILED ((void *)-1)
48
+
49
+ /* Flags for msync. */
50
+ #define MS_ASYNC 1
51
+ #define MS_SYNC 2
52
+ #define MS_INVALIDATE 4
53
+
54
+ /* Flags for portable clock_gettime call. */
55
+ #define CLOCK_REALTIME 0
56
+
57
+ void* mmap(void *addr, size_t len, int prot, int flags, int fildes, ssize_t off);
58
+ int munmap(void *addr, size_t len);
59
+ int mprotect(void *addr, size_t len, int prot);
60
+ int msync(void *addr, size_t len, int flags);
61
+ int mlock(const void *addr, size_t len);
62
+ int munlock(const void *addr, size_t len);
63
+ int clock_gettime(int clk_id, struct timespec *tp);
64
+
65
+ #ifdef __cplusplus
66
+ };
67
+ #endif
68
+
69
+ #endif /* _WIN_H_ */