par-meta committed (unverified)
Commit 7517ac2 · Parent: 63913e4

Get evals working again. (#46)


- PPL/validation: works now and uses multi-GPU. For some reason, single-GPU results differ slightly from multi-GPU; this can be debugged in a follow-up PR.
- Generation evals likely work but are very slow, so they are disabled for now.


Test Plan:
```
torchrun --nproc-per-node 8 -m bytelatent.eval config=../internal-blt/configs/eval.yaml
```
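
For context, the new `run_ppl` / `run_tasks` flags added to `EvalArgs` in this PR select which of the two eval paths run. Below is a minimal sketch of building the eval arguments programmatically rather than through `eval.yaml`; the field names match this diff, but the paths and source names are placeholders, and the exact `ValidationArgs` fields are assumed from how `eval.py` reads them.

```python
# Hedged sketch: constructing EvalArgs directly instead of via an eval.yaml config.
# Paths and source names are placeholders, not values from the repo.
from bytelatent.args import EvalArgs, ValidationArgs
from bytelatent.eval import launch_eval

eval_args = EvalArgs(
    ckpt_dir="/path/to/checkpoints/0000100000",  # placeholder checkpoint dir
    dump_dir="/path/to/eval_dump",               # placeholder output dir
    run_ppl=True,     # new flag: PPL/validation evals (multi-GPU)
    run_tasks=False,  # new flag: generation/harness evals, kept off since they are slow
    validation=ValidationArgs(
        root_dir="/path/to/validation_data",     # placeholder
        sources=["my_dataset"],                  # placeholder source names
    ),
)
launch_eval(eval_args)  # normally driven via `torchrun -m bytelatent.eval config=...`
```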

bytelatent/args.py CHANGED

```diff
@@ -270,6 +270,10 @@ class EvalArgs(BaseModel):
     dump_dir: str | None = None
     ckpt_dir: str | None = None
     metric_log_dir: str | None = None
+
+    run_ppl: bool = True
+    run_tasks: bool = False
+
     generator: PackedCausalTransformerGeneratorArgs = (
         PackedCausalTransformerGeneratorArgs()
     )
```
bytelatent/distributed.py CHANGED

```diff
@@ -15,6 +15,7 @@ from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 # for no recompute ops
@@ -78,6 +79,40 @@ class DistributedArgs(BaseModel):
 
     spawn_method: str = "forkserver"
 
+    def configure_world(self):
+        pass
+        if self.dp_replicate * self.dp_shard * self.tp_size != get_world_size():
+            logging.info("Modifying TrainArgs distributed config")
+            assert get_world_size() % self.dp_shard == 0
+            logging.info("World size: %s", get_world_size())
+            logging.info(
+                "Existing setting: train_args.distributed.dp_shard=%s",
+                self.dp_shard,
+            )
+            logging.info(
+                "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+                get_world_size() // self.dp_shard,
+                self.dp_replicate,
+            )
+            self.dp_replicate = get_world_size() // self.dp_shard
+
+            logging.info(
+                "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+                self.dp_replicate,
+                self.dp_replicate // self.tp_size,
+                self.tp_size,
+            )
+            assert self.dp_replicate % self.tp_size == 0
+            self.dp_replicate = self.dp_replicate // self.tp_size
+
+            logger.warning(
+                f"Setting Data Parallel size to {self.dp_replicate * self.dp_shard}"
+            )
+            assert self.dp_replicate * self.dp_shard * self.tp_size == get_world_size()
+
+        if self.fsdp_type == "no_shard":
+            assert self.dp_shard == 1 and self.dp_replicate == get_world_size()
+
 
 class EnvironmentArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -151,6 +186,13 @@ def dist_mean_dict(x):
     return r
 
 
+def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
+    if isinstance(num, (torch.Tensor, np.ndarray)):
+        return num.item()
+    else:
+        return num
+
+
 @lru_cache()
 def get_is_torch_run() -> bool:
     return os.environ.get("LOCAL_RANK") is not None
```
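
The `configure_world` method moved here from `train.py` reconciles `dp_replicate` with the actual world size before anything else consumes the distributed config. A small standalone sketch of that arithmetic follows; it is illustrative only, with assumed values, not a call into the repo.

```python
# Standalone illustration of the dp_replicate adjustment done in DistributedArgs.configure_world.
def adjust_dp_replicate(dp_replicate: int, dp_shard: int, tp_size: int, world_size: int) -> int:
    if dp_replicate * dp_shard * tp_size != world_size:
        assert world_size % dp_shard == 0
        dp_replicate = world_size // dp_shard  # derive data parallelism from the world size
        assert dp_replicate % tp_size == 0
        dp_replicate //= tp_size               # carve tensor-parallel ranks out of it
    assert dp_replicate * dp_shard * tp_size == world_size
    return dp_replicate

# e.g. the test-plan launch: torchrun --nproc-per-node 8 with default dp_shard=1, tp_size=1
assert adjust_dp_replicate(dp_replicate=1, dp_shard=1, tp_size=1, world_size=8) == 8
```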
bytelatent/eval.py CHANGED

```diff
@@ -2,6 +2,7 @@
 
 import json
 import logging
+import math
 import os
 from collections import defaultdict
 from datetime import datetime
@@ -10,22 +11,48 @@ import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
-
-from bytelatent.args import EvalArgs, ValidationArgs
+from rich.progress import track
+from torch.nn import functional as F
+
+from bytelatent.args import (
+    EvalArgs,
+    TrainArgs,
+    ValidationArgs,
+    find_and_sanitize_chunks,
+)
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
+from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.iterators.limit_iterator import LimitIterator
+from bytelatent.data.iterators.packing_iterator import (
+    PackingArgs,
+    PackingIterator,
+    PackingMode,
+)
+from bytelatent.data.iterators.preprocess_iterator import PreprocessIterator
+from bytelatent.data.iterators.sequence_iterator import (
+    SequenceIterator,
+    SequencePackingArgs,
+)
+from bytelatent.data.patcher import PatcherArgs, PatchingModeEnum
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
+    dist_sum,
+    get_device_mesh,
     get_global_rank,
     get_world_size,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformer
 
 EVAL_FOLDER_NAME = "{:010d}"
 
@@ -113,19 +140,134 @@ class EvalHarnessLM(LM):
         return results
 
 
-def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
-    srcs = {}
+@torch.no_grad()
+def eval_ppl_on_path(
+    *,
+    world_rank: int,
+    world_size: int,
+    model: LMTransformer | ByteLatentTransformer,
+    tokenizer_args: TokenizerArgs,
+    patcher_args: PatcherArgs,
+    add_patches: bool,
+    path: str,
+    batch_size: int,
+    arrow_batch_size: int,
+    max_n_docs: int | None,
+    s3_profile: str | None = None,
+):
+    model.eval()
+    tokenizer = tokenizer_args.build()
+    seq_len = model.get_output_seq_len()
+    chunks = find_and_sanitize_chunks(
+        path,
+        world_size=1,
+        file_pattern="*.val.jsonl",
+        s3_profile=s3_profile,
+    )
+    assert (
+        len(chunks) == 1
+    ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+    chunk = chunks[0]
+    arrow_iterator = ArrowFileIterator(
+        file_path=chunk,
+        preprocess_dir=None,
+        entropy_model_name=None,
+        worker_id=world_rank,
+        num_workers=world_size,
+        arrow_batch_size=arrow_batch_size,
+        s3_profile=s3_profile,
+        file_format="json",
+    )
+    if max_n_docs is not None:
+        arrow_iterator = LimitIterator(arrow_iterator, limit=max_n_docs)
+    preprocess_iterator = PreprocessIterator(
+        arrow_iterator,
+        patcher_args=patcher_args,
+        tokenizer_args=tokenizer_args,
+        add_patches=add_patches,
+    )
+    sequence_iterator = SequenceIterator(
+        preprocess_iterator,
+        sequence_packing_args=SequencePackingArgs(
+            output_seq_len=seq_len,
+            # Effectively disables shuffles
+            buffer_size=1,
+        ),
+        rng_state=None,
+    )
+    packing_args = PackingArgs(
+        batch_size=batch_size,
+        seq_len=seq_len,
+        # TODO: make these seq lens worth with blt
+        max_length=seq_len,
+        pad_to_max_length=True,
+        enable_byte_ngrams=False,
+        pad_id=tokenizer.boe_id,
+        packing_mode=PackingMode.BYTES,
+    )
+    packing_iterator = PackingIterator(sequence_iterator, packing_args=packing_args)
+    total_loss = 0.0
+    n_bytes = 0
+    batch_iterator = packing_iterator.create_iter()
+    for batch in batch_iterator:
+        x = torch.from_numpy(batch.x).cuda()
+        y = torch.from_numpy(batch.y).cuda()
+        mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
+        if tokenizer_args.name in ["bytes", "blt"]:
+            n_bytes += y.numel() if mask is None else mask.sum().item()
+            pred = model(x)
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            total_loss += loss.item()
+        else:
+            raise NotImplementedError()
+    all_n_bytes = to_py_num(dist_sum(n_bytes))
+    all_total_loss = to_py_num(dist_sum(total_loss))
+    return {
+        "n_bytes": all_n_bytes,
+        "n_bytes_gpu": n_bytes,
+        "loss_sum": all_total_loss,
+        "loss_sum_gpu": total_loss,
+        "loss_mean": all_total_loss / all_n_bytes,
+        "loss_mean_gpu": total_loss / n_bytes,
+        "ppl": math.exp(all_total_loss / all_n_bytes) if all_n_bytes > 0 else 0.0,
+        "bpb": all_total_loss / math.log(2) / all_n_bytes,
+    }
+
+
+def eval_on_val(generator, val_args: ValidationArgs, train_cfg: TrainArgs):
+    srcs = []
     for src in val_args.sources:
         path = os.path.join(val_args.root_dir, src)
-        srcs[path] = 1.0
+        srcs.append(path)
+
     for src in train_cfg.data.sources:
         path = os.path.join(train_cfg.data.root_dir, src)
-        srcs[path] = 1.0
-
-    multi_state = init_choice_state(
-        "", srcs, 0, get_global_rank(), get_world_size(), "*.val.jsonl"
-    )
-    path_to_iter = setup_sources(multi_state)
+        srcs.append(path)
+
+    path_to_iter = {}
+    for path in srcs:
+        chunks = find_and_sanitize_chunks(
+            path,
+            world_size=1,
+            file_pattern="*.val.jsonl",
+            s3_profile=train_cfg.data.s3_profile,
+        )
+        assert (
+            len(chunks) == 1
+        ), f"There should be only 1 chunk per validation file, but found: {chunks}"
+        chunk = chunks[0]
+        iterator = ArrowFileIterator(
+            dataset_files=[chunk],
+            file_path=None,
+            preprocess_dir=None,
+            entropy_model_name=None,
+            worker_id=0,
+            num_workers=1,
+            arrow_batch_size=train_cfg.data.arrow_batch_size,
+            s3_profile=train_cfg.data.s3_profile,
+            file_format="json",
+        )
+        path_to_iter[path] = iterator
 
     max_gen_len = generator.max_gen_len
     # We temporarily lower max gen len
@@ -133,16 +275,11 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
     all_val_metrics = {}
     for src in path_to_iter:
-        jsonl_iterator = path_to_iter[src]
+        example_iterator = path_to_iter[src].create_iter()
         texts = []
         logger.info(f"Running validation on {src}...")
-        for step, (content, state) in enumerate(jsonl_iterator):
-            if state["current_iter"] > 0 or (
-                val_args.max_steps is not None and step >= val_args.max_steps
-            ):
-                break
-            content_key = "text" if ("text" in content) else "content"
-            texts.append(content[content_key])
+        for step, example in enumerate(example_iterator):
+            texts.append(example.text)
 
         _, loglikelihood, _ = generator.generate(texts)
 
@@ -174,8 +311,18 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
 
 
 def launch_eval(eval_args: EvalArgs):
+    assert eval_args.dump_dir is not None
+    assert eval_args.ckpt_dir is not None
+    distributed_args = DistributedArgs()
+    distributed_args.configure_world()
     if not torch.distributed.is_initialized():
-        setup_torch_distributed(DistributedArgs())
+        setup_torch_distributed(distributed_args)
+
+    world_mesh = get_device_mesh(distributed_args)
+    dp_mesh = world_mesh["dp_replicate"]
+    assert distributed_args.dp_shard == 1
+    world_size = dp_mesh.size()
+    world_rank = dp_mesh.get_local_rank()
 
     fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
@@ -187,7 +334,7 @@ def launch_eval(eval_args: EvalArgs):
     else:
         consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
         if not fs.exists(consolidate_path) and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)
+            consolidate_path = consolidate_checkpoints(fs, eval_args.ckpt_dir)
 
     fs.mkdirs(eval_args.dump_dir, exist_ok=True)
     with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
@@ -200,35 +347,67 @@ def launch_eval(eval_args: EvalArgs):
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
     )
-    logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)
-
-    wrap = EvalHarnessLM(generator)
-    # Redo
-    results = simple_evaluate(wrap, eval_args.harness.model_dump())
-    val_results = None
-    if eval_args.validation:
-        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
+    logger.info("Model loaded")
+
+    ppl_results = None
+    if eval_args.run_ppl:
+        assert eval_args.validation is not None
+        if len(eval_args.validation.sources) > 0:
+            ppl_results = {}
+            logger.info("Starting PPL evaluation on validation sets")
+            for source in eval_args.validation.sources:
+                ppl_results[source] = eval_ppl_on_path(
+                    world_rank=world_rank,
+                    world_size=world_size,
+                    model=model,
+                    tokenizer_args=train_cfg.data.tokenizer_args,
+                    # TODO: Don't hardcode, modify based on model
+                    patcher_args=PatcherArgs(patching_mode=PatchingModeEnum.byte),
+                    add_patches=False,
+                    path=os.path.join(eval_args.validation.root_dir, source),
+                    max_n_docs=eval_args.validation.max_n_docs,
+                    batch_size=8,
+                    arrow_batch_size=100,
+                    s3_profile="blt",
+                )
+
+    task_results = None
+    if eval_args.run_tasks:
+        assert eval_args.generator is not None
+        assert eval_args.harness is not None
+        generator = PackedCausalTransformerGenerator(
+            eval_args.generator, model, tokenizer
+        )
+        wrap = EvalHarnessLM(generator)
+        # TODO: This needs to be checked/sped up
+        task_results = simple_evaluate(wrap, **eval_args.harness.model_dump())
+
+    results = {"ppl": ppl_results, "tasks": task_results}
+    # TODO: Serial and Parallel yield slightly different number of bytes, debug this later,
+    # leaving this log statement here to help with that.
+    # logging.info("Rank: %s Results: %s", world_rank, results)
+
     if get_global_rank() == 0:
         with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
-        logger.info(f"All evaluation results: {results['results']}")
-    if val_results is not None:
-        with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
-            f.write(json.dumps(val_results))
-        logger.info(f"All validation results: {val_results}")
+        logger.info(f"All evaluation results: {results}")
+        if ppl_results is not None:
+            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
+                f.write(json.dumps(ppl_results))
+            logger.info(f"All validation results: {ppl_results}")
+
     if eval_args.metric_log_dir and get_global_rank() == 0:
         metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")
 
         logger.info(f"Writing metric logs to {metric_log_path}")
-        timestamp = {
+        timestamp: dict[str, int | str] = {
            "created_at": datetime.utcnow().isoformat(),
        }
        if eval_args.global_step is not None:
            timestamp["global_step"] = eval_args.global_step
        print(
-            json.dumps(timestamp | results["results"]),
+            json.dumps(timestamp | results),
            file=fs.open(metric_log_path, mode="a"),
            flush=True,
        )
@@ -236,18 +415,16 @@ def launch_eval(eval_args: EvalArgs):
         val_log_path = os.path.join(
             eval_args.metric_log_dir, "metrics.validation.jsonl"
         )
-        if val_results is not None:
+        if ppl_results is not None:
             print(
-                json.dumps(timestamp | val_results),
+                json.dumps(timestamp | ppl_results),
                 file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )
 
-    del generator
-
 
 def main():
-    eval_args = parse_args(EvalArgs)
+    eval_args = parse_args_to_pydantic_model(EvalArgs)
     launch_eval(eval_args)
 
 
```
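
The metrics returned by `eval_ppl_on_path` follow directly from the summed cross-entropy over target bytes: `ppl = exp(loss_sum / n_bytes)` and `bpb = loss_sum / (ln 2 * n_bytes)`. A tiny standalone check with made-up totals (not output from the repo):

```python
import math

# Hypothetical reduced totals across all ranks (made-up numbers).
loss_sum = 123_456.0  # summed cross-entropy in nats over all scored bytes
n_bytes = 100_000     # number of target bytes that were scored

loss_mean = loss_sum / n_bytes          # average nats per byte
ppl = math.exp(loss_mean)               # byte-level perplexity, as in eval_ppl_on_path
bpb = loss_sum / math.log(2) / n_bytes  # bits per byte

print(f"loss_mean={loss_mean:.4f} ppl={ppl:.3f} bpb={bpb:.4f}")
```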
bytelatent/generate.py CHANGED

```diff
@@ -387,8 +387,7 @@ def load_consolidated_model_and_tokenizer(
 ):
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
-    with fs.open(train_args_path) as f:
-        train_args = TrainArgs.model_validate_json(f.read())
+    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
     if train_args.train_entropy_model:
         model_args = train_args.entropy_model
@@ -401,7 +400,8 @@ def load_consolidated_model_and_tokenizer(
         train_args.distributed.model_dtype
     ]
     tokenizer = train_args.data.tokenizer_args.build()
-    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
+    with fs.open(os.path.join(consolidated_path, CONSOLIDATE_NAME)) as f:
+        st_dict = torch.load(f, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
```
bytelatent/metrics.py CHANGED

```diff
@@ -55,7 +55,7 @@ class LoggingArgs(BaseModel):
 class MetricLogger:
     def __init__(
         self,
-        outdir: Path,
+        outdir: str,
         # args: TrainArgs
         args: Any | None = None,
         fs: fsspec.AbstractFileSystem | None = None,
```
bytelatent/train.py CHANGED

```diff
@@ -48,6 +48,7 @@ from bytelatent.distributed import (
     requeue_slurm_job,
     setup_env,
     setup_torch_distributed,
+    to_py_num,
 )
 from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
@@ -91,13 +92,6 @@ def get_iterator_state_name(iterator_state):
     raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")
 
 
-def to_py_num(num: int | float | torch.Tensor | np.ndarray) -> int | float:
-    if isinstance(num, (torch.Tensor, np.ndarray)):
-        return num.item()
-    else:
-        return num
-
-
 # TODO: Make this pydantic based instead of data class based
 # TODO: Generalize this to any iterator state
 @dataclass
@@ -154,57 +148,13 @@ def validate_train_args(args: TrainArgs, output_size: int):
     logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
     args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
-    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
-    for source in args.data.sources:
-        data_path = os.path.join(args.data.root_dir, source)
-        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
-
-    if (
-        args.distributed.dp_replicate
-        * args.distributed.dp_shard
-        * args.distributed.tp_size
-        != get_world_size()
-    ):
-        logging.info("Modifying TrainArgs distributed config")
-        assert get_world_size() % args.distributed.dp_shard == 0
-        logging.info("World size: %s", get_world_size())
-        logging.info(
-            "Existing setting: train_args.distributed.dp_shard=%s",
-            args.distributed.dp_shard,
-        )
-        logging.info(
-            "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
-            get_world_size() // args.distributed.dp_shard,
-            args.distributed.dp_replicate,
-        )
-        args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
-
-        logging.info(
-            "Changing dp_replicate from %s to %s, to account for tp_size=%s",
-            args.distributed.dp_replicate,
-            args.distributed.dp_replicate // args.distributed.tp_size,
-            args.distributed.tp_size,
-        )
-        assert args.distributed.dp_replicate % args.distributed.tp_size == 0
-        args.distributed.dp_replicate = (
-            args.distributed.dp_replicate // args.distributed.tp_size
-        )
-
-        logger.warning(
-            f"Setting Data Parallel size to {args.distributed.dp_replicate * args.distributed.dp_shard}"
-        )
-        assert (
-            args.distributed.dp_replicate
-            * args.distributed.dp_shard
-            * args.distributed.tp_size
-            == get_world_size()
-        )
+    if args.data.root_dir is not None:
+        data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
+        for source in args.data.sources:
+            data_path = os.path.join(args.data.root_dir, source)
+            assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
-    if args.distributed.fsdp_type == "no_shard":
-        assert (
-            args.distributed.dp_shard == 1
-            and args.distributed.dp_replicate == get_world_size()
-        )
+    args.distributed.configure_world()
 
     if args.model is not None:
         args.model.max_seqlen = args.data.seq_len
@@ -243,7 +193,9 @@ def set_preemption_flag(signum, frame):
     preemption_flag["flag"] = True
 
 
-def every_n_steps(train_state, freq, acc_step=None, acc_freq=None):
+def every_n_steps(train_state, freq: int, acc_step=None, acc_freq=None):
+    if freq < 0:
+        return False
     test = train_state.step % freq == 0
     if acc_step is not None:
         test = test and (train_state.acc_step == acc_step)
@@ -272,7 +224,7 @@ def train(args: TrainArgs):
     tokenizer = args.data.tokenizer_args.build()
     validate_train_args(
         args,
-        tokenizer.n_words,
+        tokenizer.get_vocab_size(),
     )
     dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():
```
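
One small behavioral change in `every_n_steps` worth calling out: a negative `freq` now short-circuits to `False`, so a periodic action can be switched off by giving it a negative period. A simplified stand-in for illustration (not the repo function, which also handles `acc_step` / `acc_freq`):

```python
# Simplified stand-in showing the new `freq < 0` behavior of every_n_steps.
def fires_at(step: int, freq: int) -> bool:
    if freq < 0:
        return False  # negative frequency now means "never fire"
    return step % freq == 0

assert fires_at(100, 50) is True
assert fires_at(100, -1) is False  # periodic action disabled via a negative period
```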