Initial codes and scripts for training entropy model (#34)
Files changed:
- .gitignore +1 -0
- bytelatent/args.py +11 -2
- bytelatent/configs/debug.yaml +1 -2
- bytelatent/configs/entropy_model.yaml +82 -0
- bytelatent/data/data_types.py +1 -1
- bytelatent/data/iterators/packing_iterator.py +42 -0
- bytelatent/data/iterators/sequence_iterator.py +22 -8
- bytelatent/data/patcher.py +7 -3
- bytelatent/model/blt.py +2 -3
- bytelatent/test_blt.py +1 -2
- bytelatent/train.py +39 -13
.gitignore
CHANGED
@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/
bytelatent/args.py
CHANGED
@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
 
+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
 
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
             looping_iterator,
             patcher_args=self.patcher_args,
             tokenizer_args=self.tokenizer_args,
+            add_patches=self.add_patches,
         )
         sequence_iterator = SequenceIterator(
             preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model
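For reference, a standalone sketch of the pad id selection added above. The tokenizer stand-in and the non-"bytes" tokenizer name are hypothetical; the real code reads boe_id from the tokenizer built by tokenizer_args.build().

    # Minimal sketch of the new pad_id choice (not the repo's code path).
    class FakeTokenizer:            # hypothetical stand-in for the built tokenizer
        boe_id = 257

    def select_pad_id(tokenizer_name: str, tokenizer) -> int:
        if tokenizer_name == "bytes":
            return 0                # the diff marks this choice with a TODO
        return tokenizer.boe_id

    assert select_pad_id("bytes", FakeTokenizer()) == 0
    assert select_pad_id("blt", FakeTokenizer()) == 257   # "blt" is a placeholder name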
bytelatent/configs/debug.yaml
CHANGED
@@ -26,10 +26,9 @@ model:
   vocab_size: 260
   dim_token: 256
   patch_size: 6
-  tokenization_mode: "bytes"
   patching_mode: "space"
   tie_local_encoder_decoder_logits: false
-
+  patch_in_forward: false
   max_encoder_seq_length: 12288
   pad_to_max_length: true
   patching_threshold: 3.1439168453216553
bytelatent/configs/entropy_model.yaml
ADDED
@@ -0,0 +1,82 @@
+# Template config, need to change dump_dir, data.root_dir and tokenizer.path
+# Evals can be activated by uncommenting its config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+train_entropy_model: true
+model: null
+entropy_model:
+  dim: 768
+  n_layers: 14
+  n_heads: 12
+  max_seqlen: 8192
+  # vocab_size: -1
+  vocab_size: 260
+  ffn_dim_multiplier: 1.0
+  sliding_window: 512
+  attn_bias_type: "local_block_causal"
+  attn_impl: "xformers"
+
+data:
+  s3_profile: blt
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  # seqlen is in terms of patches and
+  # max_encoder_seq_length is in terms of bytes.
+  # For entropy model, these are the same since 1 patch=1 byte
+  seq_len: 8192
+  max_encoder_seq_length: 8192
+  load_async: true
+  preprocess_dir: ???
+  # We don't need patches for this model
+  add_patches: false
+  patcher_args:
+    # This doesn't matter since byte entropy model doesn't use patching,
+    # so pick the most efficient, so static
+    patching_mode: byte
+  tokenizer_args:
+    name: bytes
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: ???
+  tasks: ???
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+mp_size: 1
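A quick way to sanity-check the new config is to load it with OmegaConf. This is only an illustration and assumes OmegaConf is installed; it is not necessarily the repo's actual config entrypoint.

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("bytelatent/configs/entropy_model.yaml")
    assert cfg.train_entropy_model and cfg.model is None
    # 1 patch = 1 byte for the entropy model, so the two lengths coincide
    assert cfg.data.seq_len == cfg.data.max_encoder_seq_length == 8192
    print(cfg.entropy_model.dim, cfg.entropy_model.n_layers)  # 768 14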
bytelatent/data/data_types.py
CHANGED
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None
 
 
 @dataclass
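With patch_lengths now optional, a sequence built for the entropy-model path simply carries None. A minimal illustration (the field values are made up):

    from bytelatent.data.data_types import BltSequence

    # Byte-level sequence for the entropy model: no patch lengths attached.
    seq = BltSequence(tokens=[65, 66, 67], mask=[True, True, True], patch_lengths=None)
    assert seq.patch_lengths is None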
bytelatent/data/iterators/packing_iterator.py
CHANGED
@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str
 
 
 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:
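The core of _create_iter_from_bytes is plain next-token packing into fixed-size padded arrays. A self-contained sketch of that inner step (the function name and toy values are illustrative only):

    import numpy as np

    def pack_bytes(token_seqs: list[list[int]], seq_len: int, pad_id: int):
        """Pad each sequence into (B, S) inputs x and shifted targets y."""
        bs = len(token_seqs)
        x = np.full((bs, seq_len), fill_value=pad_id)
        y = np.full((bs, seq_len), fill_value=pad_id)
        for i, toks in enumerate(token_seqs):
            x[i, : len(toks)] = toks
            y[i, : len(toks) - 1] = toks[1:]  # y is x shifted left by one position
        return x, y

    x, y = pack_bytes([[2, 7, 9, 4], [3, 8]], seq_len=6, pad_id=0)
    # x[0] == [2, 7, 9, 4, 0, 0] and y[0] == [7, 9, 4, 0, 0, 0]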
bytelatent/data/iterators/sequence_iterator.py
CHANGED
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
-            assert example.patch_lengths is not None
+            if self.preprocess_iterator.add_patches:
+                assert example.patch_lengths is not None
+                assert len(example.tokens) == sum(example.patch_lengths)
+            else:
+                assert example.patch_lengths is None
             assert len(example.tokens) != 0
             assert len(example.mask) != 0
             assert len(example.tokens) == len(example.mask)
-            assert len(example.tokens) == sum(example.patch_lengths)
 
             tokens.extend(example.tokens)
             mask.extend(example.mask)
-            patch_lengths.extend(example.patch_lengths)
+            if self.preprocess_iterator.add_patches:
+                patch_lengths.extend(example.patch_lengths)
+            else:
+                # This lets the rest of the code work as expected and just yield byte seqs
+                patch_lengths.extend([1] * len(example.tokens))
 
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                 == len(seq_mask[idx])
             ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
             assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
-            yield BltSequence(
-                tokens=seq_tokens[idx],
-                mask=seq_mask[idx],
-                patch_lengths=seq_patch_lengths[idx],
-            )
+            if self.preprocess_iterator.add_patches:
+                yield BltSequence(
+                    tokens=seq_tokens[idx],
+                    mask=seq_mask[idx],
+                    patch_lengths=seq_patch_lengths[idx],
+                )
+            else:
+                yield BltSequence(
+                    tokens=seq_tokens[idx],
+                    mask=seq_mask[idx],
+                    patch_lengths=None,
+                )
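When add_patches is False the iterator fakes a patching in which every byte is its own patch, so the existing patch-count buffering keeps working unchanged. A toy illustration of that invariant:

    # Every byte counted as a length-1 patch: patch totals equal byte totals.
    tokens = [101, 102, 103, 104, 105]
    patch_lengths = [1] * len(tokens)
    assert sum(patch_lengths) == len(tokens)   # same invariant the iterator asserts
    n_buffer_patches = 3
    assert sum(patch_lengths[:n_buffer_patches]) == n_buffer_patches  # a patch buffer is just n bytes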
bytelatent/data/patcher.py
CHANGED
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"
 
 
 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
     max_patch_length: int | None = None
     patch_size: float = 4.5
     patching_batch_size: int = 1
-    data_loader_patching: bool = False
     device: str = "cuda"
     monotonicity: bool = False
     log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
         self.max_patch_length = patcher_args.max_patch_length
         self.patch_size = patcher_args.patch_size
         self.patching_batch_size = patcher_args.patching_batch_size
-        self.data_loader_patching = patcher_args.data_loader_patching
         self.device = patcher_args.device
         self.monotonicity = patcher_args.monotonicity
         self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:
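The two new modes are purely shape-based. A standalone sketch of the patch lengths they produce (the helper names are ours, not the Patcher API):

    import math
    import torch

    def static_patch_lengths(bs: int, seq_len: int, patch_size: int) -> torch.Tensor:
        # Fixed-size patches, with a shorter trailing patch if seq_len is not a multiple.
        lengths = torch.full((bs, math.ceil(seq_len / patch_size)), patch_size)
        if seq_len % patch_size != 0:
            lengths[:, -1] = seq_len % patch_size
        return lengths

    def byte_patch_lengths(bs: int, seq_len: int) -> torch.Tensor:
        # One byte per patch, as entropy-model training expects.
        return torch.ones((bs, seq_len), dtype=torch.long)

    print(static_patch_lengths(1, 10, 4))  # tensor([[4, 4, 2]])
    print(byte_patch_lengths(1, 4))        # tensor([[1, 1, 1, 1]])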
bytelatent/model/blt.py
CHANGED
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
+    patch_in_forward: bool = False
 
     # Architecture and dimensions
     dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_layers_local_encoder: int = 8
 
     # Tokenization and patching
-    tokenization_mode: str = "bpe"
     patch_size: float | None = None
     patching_mode: str | None = None
     patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     monotonicity: bool = False
     patching_batch_size: int = 1
     patching_device: str = "cuda"
-    data_loader_patching: bool = False
     max_patch_length: int | None = None
 
     # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
-        if
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,
bytelatent/test_blt.py
CHANGED
@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
         # Additional args from command line
         dim_token=256,
         patch_size=6,
-        tokenization_mode="bytes",
         patching_mode="space",
         tie_local_encoder_decoder_logits=False,
-
+        patch_in_forward=False,
         max_encoder_seq_length=12288,
         pad_to_max_length=True,
         encoder_lm_loss=False,
bytelatent/train.py
CHANGED
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
 
 
 def validate_train_args(args: TrainArgs, output_size: int):
-
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
 
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size
+
     assert args.dump_dir, "Dump dir not set"
 
     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
         and args.distributed.dp_replicate == get_world_size()
     )
 
-    args.model.max_seqlen = args.data.seq_len
+    if args.model is not None:
+        args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len
 
     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
 
     # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
     with torch.device("meta"):
-        model = ByteLatentTransformer(args.model)
+        if args.train_entropy_model:
+            assert args.entropy_model is not None
+            model = LMTransformer(args.entropy_model)
+            model_args = args.entropy_model
+        else:
+            assert args.model is not None
+            model = ByteLatentTransformer(args.model)
+            model_args = args.model
     logger.info("Model is built !")
 
     model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
         world_mesh,
         args.model,
         args.distributed,
-        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+        fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
         model.rope_embeddings.reset_parameters() # For RoPe initialization since it's a buffer it might not be loaded
     else:
         with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-            torch.manual_seed(args.model.seed)
+            torch.manual_seed(model_args.seed)
             model.init_weights()
         check_model_value_range(model, range=10.0, std=1.0)
 
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
-            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
+                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,9 +431,12 @@ def train(args: TrainArgs):
                 next(probe_mod.parameters()).grad is None
             ), "Probe model shouldn't have grads at this point"
 
-            pred = model(
-                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
-            )
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
+                pred = model(
+                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+                )
 
             loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
 
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
                 # Use xformer's analyze profile trace to get actual measurement
                 FLOPS = (
                     get_num_flop_per_token(
-                        model_param_count - args.model.vocab_size * args.model.dim,
-                        args.model.n_layers,
-                        args.model.dim,
+                        model_param_count - model_args.vocab_size * model_args.dim,
+                        model_args.n_layers,
+                        model_args.dim,
                         args.data.seq_len,
                     )
                     * wps
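The FLOPS estimate now uses model_args (either the BLT args or the entropy-model args). For orientation, this is the usual dense-transformer estimate that a helper like get_num_flop_per_token typically implements; the repo's exact formula may differ.

    def flop_per_token(non_embed_params: int, n_layers: int, dim: int, seq_len: int) -> int:
        # 6 * N for the parameter matmuls plus the quadratic attention term.
        return 6 * non_embed_params + 12 * n_layers * dim * seq_len

    # model_param_count - vocab_size * dim strips the embedding/output weights,
    # which is what the call above subtracts before passing the count in.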