par-meta committed
Commit 7622d28 · unverified · 1 Parent(s): a809259

Initial codes and scripts for training entropy model (#34)

.gitignore CHANGED
@@ -166,3 +166,4 @@ figures/
 .vscode/
 .DS_Store
 internal/
+jobs_parallel-copy/
bytelatent/args.py CHANGED
@@ -93,6 +93,8 @@ class DataloaderArgs(BaseModel):
     max_encoder_seq_length: int = 12288
     enable_byte_ngrams: bool = False
 
+    add_patches: bool = True
+
     tokenizer_args: TokenizerArgs = TokenizerArgs()
     patcher_args: PatcherArgs = PatcherArgs()
 
@@ -120,6 +122,7 @@ class DataloaderArgs(BaseModel):
                 looping_iterator,
                 patcher_args=self.patcher_args,
                 tokenizer_args=self.tokenizer_args,
+                add_patches=self.add_patches,
             )
             sequence_iterator = SequenceIterator(
                 preprocess_iterator,
@@ -141,13 +144,19 @@ class DataloaderArgs(BaseModel):
             source_to_iterator=source_to_sequence_iterators,
         )
         tokenizer = self.tokenizer_args.build()
+        if self.tokenizer_args.name == "bytes":
+            # TODO: Check this with Artidoro
+            pad_id = 0
+        else:
+            pad_id = tokenizer.boe_id
         packing_args = PackingArgs(
             batch_size=self.batch_size,
             seq_len=self.seq_len,
-            pad_id=tokenizer.boe_id,
+            pad_id=pad_id,
             max_length=self.max_encoder_seq_length,
             pad_to_max_length=self.pad_to_max_length,
             enable_byte_ngrams=self.enable_byte_ngrams,
+            tokenizer_name=self.tokenizer_args.name,
         )
         packing_iterator = PackingIterator(sampling_iterator, packing_args=packing_args)
         if self.load_async:
@@ -180,7 +189,7 @@ class TrainArgs(BaseModel):
 
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
-    model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    model: ByteLatentTransformerArgs | None = ByteLatentTransformerArgs()
     # This is only needed for training the entropy model
     entropy_model: LMTransformerArgs | None = None
     # Instead of training main model, train entropy model
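The pad-id change above is easy to misread in diff form: the byte tokenizer now pads with 0, every other tokenizer keeps the old behaviour of padding with tokenizer.boe_id, and the tokenizer name is forwarded to PackingArgs so the packing iterator can pick the byte path. A minimal restatement of that rule (illustration only; select_pad_id is not a repo function, and tokenizer stands in for whatever tokenizer_args.build() returns):

# Illustration of the new padding rule in DataloaderArgs, not repo code.
def select_pad_id(tokenizer_name: str, tokenizer) -> int:
    if tokenizer_name == "bytes":
        # Per the TODO in the diff, pad_id = 0 for the bytes tokenizer is still to be confirmed.
        return 0
    # Previous behaviour: pad with the BOE token id.
    return tokenizer.boe_id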
bytelatent/configs/debug.yaml CHANGED
@@ -26,10 +26,9 @@ model:
   vocab_size: 260
   dim_token: 256
   patch_size: 6
-  tokenization_mode: "bytes"
   patching_mode: "space"
   tie_local_encoder_decoder_logits: false
-  data_loader_patching: true
+  patch_in_forward: false
   max_encoder_seq_length: 12288
   pad_to_max_length: true
   patching_threshold: 3.1439168453216553
bytelatent/configs/entropy_model.yaml ADDED
@@ -0,0 +1,82 @@
+# Template config, need to change dump_dir, data.root_dir and tokenizer.path
+# Evals can be activated by uncommenting its config
+# python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest
+
+dump_dir: /tmp/
+name: "debug"
+steps: 100_000
+probe_freq: null
+seed: 777
+optim:
+  lr: 4e-04
+  warmup: 500
+  lr_min_ratio: 0.1
+  clip: 10.0
+
+distributed:
+  fsdp_type: full_shard
+  model_dtype: bf16
+  matmul_allow_tf32: false
+  selective_activation_checkpointing: false
+  tp_size: 1
+
+train_entropy_model: true
+model: null
+entropy_model:
+  dim: 768
+  n_layers: 14
+  n_heads: 12
+  max_seqlen: 8192
+  # vocab_size: -1
+  vocab_size: 260
+  ffn_dim_multiplier: 1.0
+  sliding_window: 512
+  attn_bias_type: "local_block_causal"
+  attn_impl: "xformers"
+
+data:
+  s3_profile: blt
+  root_dir: ???
+  sources:
+    dclm_baseline_1.0: 1.0
+  batch_size: 2
+  prefetch_size: 64
+  # seqlen is in terms of patches and
+  # max_encoder_seq_length is in terms of bytes.
+  # For entropy model, these are the same since 1 patch=1 byte
+  seq_len: 8192
+  max_encoder_seq_length: 8192
+  load_async: true
+  preprocess_dir: ???
+  # We don't need patches for this model
+  add_patches: false
+  patcher_args:
+    # This doesn't matter since byte entropy model doesn't use patching,
+    # so pick the most efficient, so static
+    patching_mode: byte
+  tokenizer_args:
+    name: bytes
+
+profiling:
+  run: false
+
+checkpoint:
+  dump:
+    every: 500
+    keep: 3
+  eval:
+    every: 1000
+    keep: -1
+
+logging:
+  freq: 10
+
+eval_on_gpus: 8
+eval:
+  dataset_dir: ???
+  tasks: ???
+  generator:
+    max_tokens: 65536
+    dtype: bf16
+
+  mp_size: 1
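Before launching a run, the new switches in this config can be sanity-checked directly. A small sketch using plain PyYAML to inspect the raw file (the training CLI does its own config parsing; this is illustration only):

# Check the entropy-model switches introduced by this commit.
import yaml

with open("bytelatent/configs/entropy_model.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["train_entropy_model"] is True               # train the entropy model, not BLT
assert cfg["model"] is None                             # main BLT model is not built
assert cfg["entropy_model"]["vocab_size"] == 260        # byte-level vocabulary
assert cfg["data"]["add_patches"] is False              # byte sequences carry no patch lengths
assert cfg["data"]["tokenizer_args"]["name"] == "bytes"
assert cfg["data"]["seq_len"] == cfg["data"]["max_encoder_seq_length"]  # 1 patch == 1 byte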
bytelatent/data/data_types.py CHANGED
@@ -53,7 +53,7 @@ BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
-    patch_lengths: list[int]
+    patch_lengths: list[int] | None
 
 
 @dataclass
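With patch_lengths now optional, a byte-level sequence can be represented without any patch information. A short sketch of both cases (toy values):

from bytelatent.data.data_types import BltSequence

# Patch-based sequence (BLT training): patch lengths sum to the number of tokens.
patched = BltSequence(tokens=[3, 4, 5, 6], mask=[True] * 4, patch_lengths=[2, 2])

# Byte-level sequence (entropy-model training): no patch lengths at all.
byte_level = BltSequence(tokens=[3, 4, 5, 6], mask=[True] * 4, patch_lengths=None)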
bytelatent/data/iterators/packing_iterator.py CHANGED
@@ -17,6 +17,7 @@ class PackingArgs(BaseModel):
     max_length: int | None
     pad_to_max_length: bool
     enable_byte_ngrams: bool
+    tokenizer_name: str
 
 
 class PackingIteratorState(BaseModel, IteratorState):
@@ -151,6 +152,43 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
         )
 
     def create_iter(self):
+        if self.packing_args.tokenizer_name == "bytes":
+            return self._create_iter_from_bytes()
+        else:
+            return self._create_iter_from_patch_lengths()
+
+    def _create_iter_from_bytes(self):
+        sequence_iter = self.sequence_iterator.create_iter()
+        batch_size = self.packing_args.batch_size
+        pad_id = self.packing_args.pad_id
+        seq_len = self.packing_args.seq_len
+        while True:
+            tokens: list[list[int]] = []
+            masks: list[list[bool]] = []
+
+            for _ in range(self.packing_args.batch_size):
+                sequence = next(sequence_iter)
+                _tokens = sequence.tokens
+                _mask = sequence.mask
+                assert (
+                    sequence.patch_lengths is None
+                ), "patch_lengths should not be used in byte packing"
+                tokens.append(_tokens)
+                masks.append(_mask)
+
+            x = np.full((batch_size, seq_len), fill_value=pad_id)
+            y = np.full((batch_size, seq_len), fill_value=pad_id)
+
+            for i, tok_seq in enumerate(tokens):
+                x[i, : len(tok_seq)] = tok_seq
+                y[i, : len(tok_seq) - 1] = tok_seq[1:]
+            batch = Batch(x=x, y=y)
+            assert (
+                batch.mask is None or np.sum(x != pad_id) == batch.mask.sum()
+            ), f"{np.sum(x != pad_id)} != {batch.mask.sum()}"
+            yield batch
+
+    def _create_iter_from_patch_lengths(self):
         sequence_iter = self.sequence_iterator.create_iter()
         batch_size = self.packing_args.batch_size
         pad_id = self.packing_args.pad_id
@@ -168,6 +206,10 @@ class PackingIterator(StatefulIterator[Batch, PackingIteratorState]):
                 _tokens = sequence.tokens
                 _mask = sequence.mask
                 _patch_lengths = sequence.patch_lengths
+                assert (
+                    _patch_lengths is not None
+                ), "patch lengths are required for packing based on patches."
+                # Reminder: seq_len is in terms of patches
                 assert len(sequence.patch_lengths) == self.packing_args.seq_len
                 last_patch_length = 0
                 if _patch_lengths[0] > 1:
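The core of the new _create_iter_from_bytes is plain next-byte shifting into padded arrays. A standalone rerun of that logic with toy values (not repo code):

import numpy as np

pad_id, seq_len = 0, 8
tokens = [[3, 4, 5, 6], [7, 8, 9]]           # toy byte sequences of unequal length
batch_size = len(tokens)

x = np.full((batch_size, seq_len), fill_value=pad_id)
y = np.full((batch_size, seq_len), fill_value=pad_id)
for i, tok_seq in enumerate(tokens):
    x[i, : len(tok_seq)] = tok_seq           # inputs
    y[i, : len(tok_seq) - 1] = tok_seq[1:]   # targets, shifted left by one byte

print(x)  # [[3 4 5 6 0 0 0 0]
          #  [7 8 9 0 0 0 0 0]]
print(y)  # [[4 5 6 0 0 0 0 0]
          #  [8 9 0 0 0 0 0 0]]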
bytelatent/data/iterators/sequence_iterator.py CHANGED
@@ -70,15 +70,22 @@ class SequenceIterator(StatefulIterator):
         for example in example_iter:
             assert example.tokens is not None
             assert example.mask is not None
-            assert example.patch_lengths is not None
+            if self.preprocess_iterator.add_patches:
+                assert example.patch_lengths is not None
+                assert len(example.tokens) == sum(example.patch_lengths)
+            else:
+                assert example.patch_lengths is None
             assert len(example.tokens) != 0
             assert len(example.mask) != 0
             assert len(example.tokens) == len(example.mask)
-            assert len(example.tokens) == sum(example.patch_lengths)
 
             tokens.extend(example.tokens)
             mask.extend(example.mask)
-            patch_lengths.extend(example.patch_lengths)
+            if self.preprocess_iterator.add_patches:
+                patch_lengths.extend(example.patch_lengths)
+            else:
+                # This lets the rest of the code work as expected and just yield byte seqs
+                patch_lengths.extend([1] * len(example.tokens))
 
             while len(patch_lengths) >= n_buffer_patches:
                 if first:
@@ -115,8 +122,15 @@ class SequenceIterator(StatefulIterator):
                 == len(seq_mask[idx])
             ), f"{sum(seq_patch_lengths[idx])}, {len(seq_tokens[idx])} {len(seq_mask[idx])}, idx={idx}"
             assert seq_patch_lengths[idx][0] > 0, f"{seq_patch_lengths[idx]}"
-            yield BltSequence(
-                tokens=seq_tokens[idx],
-                mask=seq_mask[idx],
-                patch_lengths=seq_patch_lengths[idx],
-            )
+            if self.preprocess_iterator.add_patches:
+                yield BltSequence(
+                    tokens=seq_tokens[idx],
+                    mask=seq_mask[idx],
+                    patch_lengths=seq_patch_lengths[idx],
+                )
+            else:
+                yield BltSequence(
+                    tokens=seq_tokens[idx],
+                    mask=seq_mask[idx],
+                    patch_lengths=None,
+                )
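When add_patches is off, the iterator substitutes a patch length of 1 per byte so the existing patch-based buffering keeps working unchanged; this is also why seq_len and max_encoder_seq_length are equal in the entropy-model config. A tiny illustration (toy values, not repo code):

example_tokens = [65, 66, 67, 68, 69]
add_patches = False

if add_patches:
    patch_lengths = [2, 3]                      # would come from example.patch_lengths
else:
    patch_lengths = [1] * len(example_tokens)   # 1 patch == 1 byte

assert sum(patch_lengths) == len(example_tokens)
print(patch_lengths)  # [1, 1, 1, 1, 1]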
bytelatent/data/patcher.py CHANGED
@@ -22,6 +22,8 @@ class PatchingModeEnum(str, Enum):
     bpe = "bpe"
     bpe_patcher = "bpe_patcher"
     space = "space"
+    static = "static"
+    byte = "byte"
 
 
 class PatcherArgs(BaseModel):
@@ -34,7 +36,6 @@ class PatcherArgs(BaseModel):
     max_patch_length: int | None = None
     patch_size: float = 4.5
     patching_batch_size: int = 1
-    data_loader_patching: bool = False
     device: str = "cuda"
     monotonicity: bool = False
     log_time: bool = False
@@ -486,7 +487,6 @@ class Patcher:
         self.max_patch_length = patcher_args.max_patch_length
         self.patch_size = patcher_args.patch_size
         self.patching_batch_size = patcher_args.patching_batch_size
-        self.data_loader_patching = patcher_args.data_loader_patching
        self.device = patcher_args.device
         self.monotonicity = patcher_args.monotonicity
         self.log_time = patcher_args.log_time
@@ -528,7 +528,7 @@ class Patcher:
         seq_len_next_tok = seq_len + 1 if include_next_token else seq_len
         scores = None
         # STATIC
-        if self.patching_mode is None:
+        if self.patching_mode == PatchingModeEnum.static:
             patch_lengths = torch.zeros(
                 (bs, math.ceil(seq_len_next_tok / self.patch_size)),
                 dtype=tokens.dtype,
@@ -536,6 +536,10 @@ class Patcher:
             ).fill_(self.patch_size)
             if seq_len_next_tok % self.patch_size != 0:
                 patch_lengths[:, -1] = seq_len_next_tok % self.patch_size
+        elif self.patching_mode == PatchingModeEnum.byte:
+            patch_lengths = torch.ones(
+                (bs, seq_len_next_tok), dtype=tokens.dtype, device=tokens.device
+            )
         # ENTROPY
         elif self.patching_mode == PatchingModeEnum.entropy:
             if self.log_time:
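The two non-learned branches above are easy to verify in isolation: static fills fixed-size patches (with a shorter final patch), and the new byte mode emits one patch per position. A standalone check with toy shapes (not repo code):

import math
import torch

bs, seq_len_next_tok, patch_size = 1, 10, 4
tokens = torch.zeros(bs, seq_len_next_tok, dtype=torch.long)

# PatchingModeEnum.static: ceil(10 / 4) = 3 patches of size 4, last one truncated to 2.
static_lengths = torch.zeros(
    (bs, math.ceil(seq_len_next_tok / patch_size)), dtype=tokens.dtype
).fill_(patch_size)
if seq_len_next_tok % patch_size != 0:
    static_lengths[:, -1] = seq_len_next_tok % patch_size
print(static_lengths)  # tensor([[4, 4, 2]])

# PatchingModeEnum.byte: one patch per byte.
byte_lengths = torch.ones((bs, seq_len_next_tok), dtype=tokens.dtype)
print(byte_lengths.sum().item())  # 10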
bytelatent/model/blt.py CHANGED
@@ -411,6 +411,7 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
+    patch_in_forward: bool = False
 
     # Architecture and dimensions
     dim_token: int = 256
@@ -422,7 +423,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_layers_local_encoder: int = 8
 
     # Tokenization and patching
-    tokenization_mode: str = "bpe"
     patch_size: float | None = None
     patching_mode: str | None = None
     patching_threshold: float | None = None
@@ -430,7 +430,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     monotonicity: bool = False
     patching_batch_size: int = 1
     patching_device: str = "cuda"
-    data_loader_patching: bool = False
     max_patch_length: int | None = None
 
     # Encoder/Decoder configuration
@@ -856,7 +855,7 @@ class ByteLatentTransformer(nn.Module):
             self.output.weight = self.tok_embeddings.weight
 
         # Patcher module
-        if not args.data_loader_patching:
+        if args.patch_in_forward:
             self.patcher = Patcher(
                 PatcherArgs(
                     patch_size=args.patch_size,
bytelatent/test_blt.py CHANGED
@@ -68,10 +68,9 @@ def create_args(cross_attention=False):
         # Additional args from command line
         dim_token=256,
         patch_size=6,
-        tokenization_mode="bytes",
         patching_mode="space",
         tie_local_encoder_decoder_logits=False,
-        data_loader_patching=True,
+        patch_in_forward=False,
         max_encoder_seq_length=12288,
         pad_to_max_length=True,
         encoder_lm_loss=False,
bytelatent/train.py CHANGED
@@ -47,6 +47,7 @@ from bytelatent.probe import AutoProbeD
 from bytelatent.profiling import maybe_run_profiler
 from bytelatent.stool import StoolArgs, launch_job
 from bytelatent.transformer import (
+    LMTransformer,
     build_fsdp_grouping_plan,
     get_no_recompute_ops,
     get_num_flop_per_token,
@@ -103,10 +104,15 @@ class TrainState(Stateful):
 
 
 def validate_train_args(args: TrainArgs, output_size: int):
-    if args.model.vocab_size < 0:
+    assert args.model is not None or args.entropy_model is not None
+    if args.model is not None:
         logger.info(f"Setting model output size to {args.model.vocab_size}")
         args.model.vocab_size = output_size
 
+    if args.entropy_model is not None:
+        logger.info(f"Setting model output size to {args.entropy_model.vocab_size}")
+        args.entropy_model.vocab_size = output_size
+
     assert args.dump_dir, "Dump dir not set"
 
     if args.checkpoint.path is None:
@@ -147,7 +153,10 @@ def validate_train_args(args: TrainArgs, output_size: int):
         and args.distributed.dp_replicate == get_world_size()
     )
 
-    args.model.max_seqlen = args.data.seq_len
+    if args.model is not None:
+        args.model.max_seqlen = args.data.seq_len
+    if args.entropy_model is not None:
+        args.entropy_model.max_seqlen = args.data.seq_len
 
     if args.distributed.tp_size == 1:
         logger.warning(
@@ -237,7 +246,14 @@ def train(args: TrainArgs):
 
     # Initializing Model in meta device allows us to initialize models much bigger than 1 gpu's memory
     with torch.device("meta"):
-        model = ByteLatentTransformer(args.model)
+        if args.train_entropy_model:
+            assert args.entropy_model is not None
+            model = LMTransformer(args.entropy_model)
+            model_args = args.entropy_model
+        else:
+            assert args.model is not None
+            model = ByteLatentTransformer(args.model)
+            model_args = args.model
     logger.info("Model is built !")
 
     model_param_count = get_num_params(model)
@@ -247,7 +263,7 @@ def train(args: TrainArgs):
         world_mesh,
         args.model,
         args.distributed,
-        fsdp_grouping_plan=build_fsdp_grouping_plan(args.model),
+        fsdp_grouping_plan=build_fsdp_grouping_plan(model_args),
         tp_parallelize=tp_parallelize,
         no_recompute_ops=get_no_recompute_ops(),
     )
@@ -267,7 +283,7 @@ def train(args: TrainArgs):
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
         with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
-            torch.manual_seed(args.model.seed)
+            torch.manual_seed(model_args.seed)
             model.init_weights()
         check_model_value_range(model, range=10.0, std=1.0)
 
@@ -342,10 +358,17 @@ def train(args: TrainArgs):
                 batch.x,
             ).cuda()
             batch_y = torch.from_numpy(batch.y).cuda()
-            batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
+            if batch.patch_lengths is None:
+                batch_patch_lengths = None
+            else:
+                batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
 
-            if args.model.encoder_enable_byte_ngrams and batch.ngram_ids is None:
+            if (
+                not args.train_entropy_model
+                and args.model.encoder_enable_byte_ngrams
+                and batch.ngram_ids is None
+            ):
                 raise ValueError(
                     "Cannot enable byte ngrams and have batch.ngram_ids be None"
                 )
@@ -408,9 +431,12 @@ def train(args: TrainArgs):
                 next(probe_mod.parameters()).grad is None
             ), "Probe model shouldn't have grads at this point"
 
-            pred = model(
-                batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
-            )
+            if args.train_entropy_model:
+                pred = model(batch_x)
+            else:
+                pred = model(
+                    batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
+                )
 
             loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
 
@@ -474,9 +500,9 @@ def train(args: TrainArgs):
             # Use xformer's analyze profile trace to get actual measurement
             FLOPS = (
                 get_num_flop_per_token(
-                    model_param_count - args.model.vocab_size * args.model.dim,
-                    args.model.n_layers,
-                    args.model.dim,
+                    model_param_count - model_args.vocab_size * model_args.dim,
+                    model_args.n_layers,
+                    model_args.dim,
                     args.data.seq_len,
                 )
                 * wps
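The training-loop changes all hang off one branch: build LMTransformer from the entropy_model args when train_entropy_model is set, otherwise build ByteLatentTransformer as before, and reuse model_args everywhere the code previously reached for args.model. A condensed sketch of that branch (mirrors the diff above; build_model is not a function in the repo):

import torch
from bytelatent.model.blt import ByteLatentTransformer
from bytelatent.transformer import LMTransformer

def build_model(args):
    # Condensed restatement of the branch added to train().
    with torch.device("meta"):
        if args.train_entropy_model:
            assert args.entropy_model is not None
            model = LMTransformer(args.entropy_model)
            model_args = args.entropy_model
        else:
            assert args.model is not None
            model = ByteLatentTransformer(args.model)
            model_args = args.model
    # model_args is what the rest of train() now uses for FSDP grouping,
    # seeding, and FLOPS accounting instead of args.model.
    return model, model_args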