par-meta committed
Commit 7044771 · unverified · 1 Parent(s): 7622d28

This includes fixes that make checkpointing and reloading work correctly. (#35)

It also batches in a first set of changes for fixing eval code.

Summary:

Test Plan:

apps/main/lingua_train.py CHANGED

@@ -544,7 +544,7 @@ def train(args: TrainArgs):
         if args.eval is not None and every_n_steps(
             train_state, args.checkpoint.eval.every, acc_step=0
         ):
-            from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
+            from bytelatent.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval

             eval_args = dataclass_from_dict(EvalArgs, args.eval)

bytelatent/args.py CHANGED

@@ -5,6 +5,7 @@ from typing import Any

 import numpy as np
 import yaml
+from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
@@ -39,6 +40,19 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
     return np.random.default_rng((seed, rank, world_size)).bit_generator.state


+def parse_args(args_cls):
+    cli_args = OmegaConf.from_cli()
+    file_cfg = OmegaConf.load(cli_args.config)
+    # We remove 'config' attribute from config as the underlying DataClass does not have it
+    del cli_args.config
+
+    default_cfg = OmegaConf.create(args_cls().model_dump())
+    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
+    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+    pydantic_args = args_cls.model_validate(cfg)
+    return pydantic_args
+
+
 def distribute_data_to_rank(
     *,
     dataset_path: str,
@@ -71,6 +85,22 @@ def distribute_data_to_rank(
     return rank_to_arrow_iterator_params[rank]


+class PackedCausalTransformerGeneratorArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    temperature: float = 0.0
+    top_p: float | None = None
+    top_k: float | None = None
+    max_gen_len: int = 512  # Maximum number of tokens to generate
+    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
+    max_prompt_len: int | None = None
+    until: list[str] = []
+    compile_prefilling: bool = False
+    reduce_generation_overhead: bool = False
+    show_progress: bool = False
+    dtype: str | None = "bf16"
+    device: str | None = "cuda"
+
+
 class DataloaderArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     s3_profile: str | None = None
@@ -168,6 +198,58 @@ class DataloaderArgs(BaseModel):
         return packing_iterator


+class LMHarnessArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    tasks: list[Any] | None = None
+    num_fewshot: int | None = None
+    device: str | None = None
+    use_cache: str | None = None
+    cache_requests: bool = False
+    rewrite_requests_cache: bool = False
+    delete_requests_cache: bool = False
+    limit: int | float | None = None
+    bootstrap_iters: int = 100000
+    check_integrity: bool = False
+    write_out: bool = False
+    log_samples: bool = True
+    system_instruction: str | None = None
+    apply_chat_template: bool | str = False
+    fewshot_as_multiturn: bool = False
+    gen_kwargs: str | None = None
+    verbosity: str = "INFO"
+    predict_only: bool = False
+    random_seed: int = 0
+    numpy_random_seed: int = 1234
+    torch_random_seed: int = 1234
+    fewshot_random_seed: int = 1234
+
+
+class ValidationArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    max_steps: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
+    use_val_from_train_src: bool = True  # Use the validation set from training sources
+    root_dir: str = ""
+    sources: list[str] = []  # Other sources to eval on
+
+
+class EvalArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    dump_dir: str
+    ckpt_dir: str
+    metric_log_dir: str | None = None
+    generator: PackedCausalTransformerGeneratorArgs = (
+        PackedCausalTransformerGeneratorArgs()
+    )
+
+    harness: LMHarnessArgs | None = LMHarnessArgs()
+    validation: ValidationArgs | None = ValidationArgs()
+
+    global_step: int | None = None  # for in-training evaluation
+    s3_profile: str | None = None
+
+
 class TrainArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     name: str = "lingua"
@@ -186,6 +268,9 @@ class TrainArgs(BaseModel):

     # Nb optimizer steps to take
     steps: int = 1000
+    # If not None, halt training after this many steps,
+    # useful for debugging
+    max_steps: int | None = None

     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
@@ -203,7 +288,7 @@ class TrainArgs(BaseModel):

     # If set to None, eval is run locally otherwise it launches a new job with the given number of gpus
     async_eval_gpus: int | None = None
-    eval: Any | None = None
+    eval: EvalArgs | None = None
     eval_on_gpus: int | None = None

     def dump_to_yaml_file(
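
The new parse_args helper above centralizes the OmegaConf-based config handling that train.py and eval.py previously duplicated: pydantic defaults are overridden by the YAML file named via config=..., which is in turn overridden by dot-list arguments from the command line, and the merged result is validated back into the pydantic model. A minimal standalone sketch of that merge order (the config path and overrides below are placeholders, not part of this commit):

    from omegaconf import OmegaConf

    from bytelatent.args import TrainArgs

    # Same precedence as parse_args(TrainArgs): defaults < YAML file < CLI dot-list.
    default_cfg = OmegaConf.create(TrainArgs().model_dump())
    file_cfg = OmegaConf.load("bytelatent/configs/debug.yaml")  # placeholder config path
    cli_cfg = OmegaConf.from_dotlist(["steps=10", "name=debug-run"])  # example overrides

    merged = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
    train_args = TrainArgs.model_validate(
        OmegaConf.to_container(merged, resolve=True, throw_on_missing=True)
    )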
bytelatent/checkpoint.py CHANGED

@@ -7,6 +7,7 @@ import re
 from pathlib import Path
 from typing import List, Optional, Tuple

+import fsspec
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -21,6 +22,7 @@ from torch.distributed.checkpoint.state_dict import (
     set_state_dict,
 )

+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import get_is_master

 logger = logging.getLogger("CHECKPOINT")
@@ -51,13 +53,14 @@ class CheckpointArgs(BaseModel):
     path: str | None = None
     init_ckpt_path: str | None = None
     continue_training_from_init: bool = False
+    s3_profile: str | None = None


 def _get_key_step(name: str):
     return int(re.findall(RE_DIGITS, name)[-1])


-def consolidate_checkpoints(ckpt_dir: str):
+def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
     """
     Consolidates all FSDP checkpoints in a directory to a single file
     Consolidate checkpoint is saved in a subdirectory of ckpt_dir
@@ -102,15 +105,17 @@ def load_from_checkpoint(
     dcp.load(state_dict, checkpoint_id=ckpt_dir)


+# TODO: Rewrite the file operations here to use fsspec to enable s3 writing.
 class CheckpointManager:
     def __init__(self, args: CheckpointArgs):
         self.path = args.path
+        self.fs = get_fs(self.path, s3_profile=args.s3_profile)
         self.dump_every = args.dump
         self.eval_every = args.eval
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init

-        assert os.path.exists(
+        assert self.fs.exists(
             self.path
         ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
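
CheckpointManager now resolves a filesystem through get_fs, so the same existence checks can work for local paths today and s3:// paths once the remaining file operations are ported (per the TODO above). A rough usage sketch, assuming get_fs returns a standard fsspec filesystem as the call sites in this diff suggest (the directory below is a placeholder):

    import fsspec

    from bytelatent.data.file_util import get_fs

    ckpt_dir = "/tmp/blt/checkpoints"  # placeholder path; an s3:// URL is the eventual target
    fs: fsspec.AbstractFileSystem = get_fs(ckpt_dir, s3_profile=None)

    if not fs.exists(ckpt_dir):
        fs.mkdirs(ckpt_dir, exist_ok=True)
    print(fs.ls(ckpt_dir))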
 
bytelatent/configs/debug.yaml CHANGED

@@ -98,11 +98,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: /checkpoint/amaia/codegen/datasets/eval
-  tasks: boolq,hellaswag,nq,piqa,siqa,tqa,winogrande,obqa,arc_easy,arc_challenge,race.middle,race.high,gsm8k,math,bbh,copa,human_eval_plus,mbpp,mmlu
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null

bytelatent/configs/entropy_model.yaml CHANGED

@@ -72,11 +72,4 @@ logging:
   freq: 10

 eval_on_gpus: 8
-eval:
-  dataset_dir: ???
-  tasks: ???
-  generator:
-    max_tokens: 65536
-    dtype: bf16
-
-  mp_size: 1
+eval: null

bytelatent/data/data_types.py CHANGED

@@ -40,16 +40,6 @@ class BltPackTokensState(BaseModel):
     n_views: int = 2


-class DataLoaderState(BaseModel):
-    model_config = ConfigDict(extra="forbid")
-    multi_choice_state: MultiChoiceState
-    pack_tokens_state: BltPackTokensState
-    prefetch_state: PrefetchState
-
-
-BltIterator = Iterator[tuple[BltExample, DataLoaderState]]
-
-
 class BltSequence(BaseModel):
     tokens: list[int]
     mask: list[bool]
bytelatent/data/iterators/multiprocess_iterator.py CHANGED

@@ -128,6 +128,13 @@ class MultiprocessIterator(StatefulIterator):
         self.producer = None
         self.stop_iterating_event = None
         self.state_dumped_event = None
+        self.force_shutdown = False
+
+    def shutdown(self):
+        if self.producer is not None:
+            # This properly shuts things down
+            self.producer.kill()
+            self.force_shutdown = True

     def get_state(self) -> MultiprocessIteratorState:
         """
@@ -135,6 +142,10 @@
         to halt the background process and allow it to write the state to the main loop
         in order to not lose data
         """
+        if self.force_shutdown:
+            raise ValueError(
+                "State will be invalid if shutdown was forced before state persisted."
+            )
         if self.producer is None:
             serialized_prefetch_buffer = json.dumps(
                 [b.to_python_dict() for b in self.prefetch_buffer]
@@ -187,6 +198,10 @@
         )

     def create_iter(self):
+        if self.force_shutdown:
+            raise ValueError(
+                "Iterator may be invalid if shutdown was forced before state persisted."
+            )
         logging.info("Main thread: Creating MP iterator")
         # First yield from the stored prefetch buffer.
         if self.prefetch_buffer is not None:
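
The force_shutdown flag encodes an ordering contract: once shutdown() has killed the background producer, get_state() and create_iter() refuse to run because the state they would return can no longer be trusted. A small illustrative helper (not part of the diff) showing the order the training loop is expected to follow when checkpointing and then exiting:

    from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator


    def checkpoint_then_close(data_loader: MultiprocessIterator) -> dict:
        """Capture the dataloader state for a checkpoint, then stop the producer.

        Order matters: after shutdown() the force_shutdown guard makes
        get_state() and create_iter() raise.
        """
        state = data_loader.get_state()  # MultiprocessIteratorState (pydantic)
        data_loader.shutdown()           # kills the producer process
        return state.model_dump()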
{apps/main → bytelatent}/eval.py RENAMED

@@ -4,20 +4,20 @@ import json
 import logging
 import os
 from collections import defaultdict
-from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from pathlib import Path
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any

 import torch
-from lingua.args import dump_config
-from lingua.data import init_choice_state, setup_sources
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from omegaconf import OmegaConf
+from pydantic import BaseModel, ConfigDict

+from bytelatent.args import EvalArgs, ValidationArgs, parse_args
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import (
     DistributedArgs,
     dist_mean_dict,
@@ -25,72 +25,17 @@ from bytelatent.distributed import (
     get_world_size,
     setup_torch_distributed,
 )
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
-
-from apps.main.generate import (
+from bytelatent.generate import (
     PackedCausalTransformerGenerator,
-    PackedCausalTransformerGeneratorArgs,
     load_consolidated_model_and_tokenizer,
 )
+from bytelatent.transformer import LMTransformer, LMTransformerArgs

 EVAL_FOLDER_NAME = "{:010d}"

 logger = logging.getLogger()


-@dataclass
-class LMHarnessArgs:
-    tasks: Optional[List[Any]] = None
-    num_fewshot: Optional[int] = None
-    device: Optional[str] = None
-    use_cache: Optional[str] = None
-    cache_requests: bool = False
-    rewrite_requests_cache: bool = False
-    delete_requests_cache: bool = False
-    limit: Optional[Union[int, float]] = None
-    bootstrap_iters: int = 100000
-    check_integrity: bool = False
-    write_out: bool = False
-    log_samples: bool = True
-    system_instruction: Optional[str] = None
-    apply_chat_template: Union[bool, str] = False
-    fewshot_as_multiturn: bool = False
-    gen_kwargs: Optional[str] = None
-    verbosity: str = "INFO"
-    predict_only: bool = False
-    random_seed: int = 0
-    numpy_random_seed: int = 1234
-    torch_random_seed: int = 1234
-    fewshot_random_seed: int = 1234
-
-
-@dataclass
-class ValidationArgs:
-    max_steps: Optional[int] = (
-        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
-    )
-    use_val_from_train_src: bool = True  # Use the validation set from training sources
-    root_dir: str = ""
-    sources: List[str] = field(default_factory=list)  # Other sources to eval on
-
-
-@dataclass
-class EvalArgs:
-    name: str = "evals"
-    dump_dir: Optional[str] = None
-    metric_log_dir: Optional[str] = None
-    ckpt_dir: str = ""
-    generator: PackedCausalTransformerGeneratorArgs = field(
-        default_factory=PackedCausalTransformerGeneratorArgs
-    )
-    harness: Optional[LMHarnessArgs] = field(default_factory=LMHarnessArgs)
-    validation: Optional[ValidationArgs] = field(default_factory=ValidationArgs)
-
-    wandb: Optional[Any] = None
-
-    global_step: Optional[int] = None  # for in-training evaluation
-
-
 def all_dicts_same(dict_list):
     if not dict_list:  # Check if the list is empty
         return True
@@ -120,7 +65,7 @@ class EvalHarnessLM(LM):
         self._world_size = get_world_size()
         self.device = generator.device

-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         prompts, gen_args = zip(*[req.args for req in requests])
         assert all_dicts_same(gen_args), "Doesn't support different gen args for now"
         gen_args = gen_args[0]
@@ -141,7 +86,7 @@ class EvalHarnessLM(LM):
             filtered_gen.append(g)
         return filtered_gen

-    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         prompts, continuations = zip(*[req.args for req in requests])
         inputs = [req.args[0] + req.args[1] for req in requests]
         max_gen_len = self.generator.max_gen_len
@@ -158,7 +103,7 @@ class EvalHarnessLM(LM):
         self.generator.max_gen_len = max_gen_len
         return results

-    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         prompts = [req.args[0] for req in requests]
         max_gen_len = self.generator.max_gen_len
         # We temporarily lower max gen len
@@ -232,68 +177,73 @@ def eval_on_val(generator, val_args: ValidationArgs, train_cfg):
     return all_val_metrics


-def launch_eval(cfg: EvalArgs):
+def launch_eval(eval_args: EvalArgs):
     if not torch.distributed.is_initialized():
         setup_torch_distributed(DistributedArgs())
+
+    fs = get_fs(eval_args.ckpt_dir, s3_profile=eval_args.s3_profile)
     if (
-        Path(cfg.ckpt_dir).exists()
-        and (Path(cfg.ckpt_dir) / "params.json").exists()
-        and next(Path(cfg.ckpt_dir).glob("*.pth"), None) is not None
+        fs.exists(eval_args.ckpt_dir)
+        and fs.exists(os.path.join(eval_args.ckpt_dir, "params.json"))
+        and len(fs.glob(os.path.join(eval_args.ckpt_dir, "*.pth"))) != 0
     ):
-        consolidate_path = Path(cfg.ckpt_dir)
+        consolidate_path = eval_args.ckpt_dir
     else:
-        consolidate_path = Path(cfg.ckpt_dir) / CONSOLIDATE_FOLDER
-        if not consolidate_path.exists() and get_global_rank() == 0:
-            consolidate_path = consolidate_checkpoints(cfg.ckpt_dir)
+        consolidate_path = os.path.join(eval_args.ckpt_dir, CONSOLIDATE_FOLDER)
+        if not fs.exists(consolidate_path) and get_global_rank() == 0:
+            consolidate_path = consolidate_checkpoints(eval_args.ckpt_dir)

-    Path(cfg.dump_dir).mkdir(parents=True, exist_ok=True)
-    dump_config(cfg, Path(cfg.dump_dir) / "config.yaml", log_config=False)
+    fs.mkdirs(eval_args.dump_dir, exist_ok=True)
+    with fs.open(os.path.join(eval_args.dump_dir, "config.yaml"), "w") as f:
+        f.write(eval_args.model_dump_json())

-    consolidate_path = str(consolidate_path)
     torch.distributed.barrier()
     logger.info("Loading model")
+    # TODO: Make this general so that it works with either
+    # LMTransformer or Blt, similar with args
     model, tokenizer, train_cfg = load_consolidated_model_and_tokenizer(
         consolidate_path,
-        model_cls=LMTransformer,
-        model_args_cls=LMTransformerArgs,
     )
     logger.info("Model loaded")
     model.eval()
-    generator = PackedCausalTransformerGenerator(cfg.generator, model, tokenizer)
+    generator = PackedCausalTransformerGenerator(eval_args.generator, model, tokenizer)

     wrap = EvalHarnessLM(generator)
-    results = simple_evaluate(wrap, **asdict(cfg.harness))
+    # Redo
+    results = simple_evaluate(wrap, eval_args.harness.model_dump())
     val_results = None
-    if cfg.validation:
-        val_results = eval_on_val(generator, cfg.validation, train_cfg)
+    if eval_args.validation:
+        val_results = eval_on_val(generator, eval_args.validation, train_cfg)
     if get_global_rank() == 0:
-        with open(Path(cfg.dump_dir) / "results.json", "w") as f:
+        with fs.open(os.path.join(eval_args.dump_dir, "results.json"), "w") as f:
             f.write(json.dumps(results))
         logger.info(f"All evaluation results: {results['results']}")
         if val_results is not None:
-            with open(Path(cfg.dump_dir) / "validation.json", "w") as f:
+            with fs.open(os.path.join(eval_args.dump_dir, "validation.json"), "w") as f:
                 f.write(json.dumps(val_results))
             logger.info(f"All validation results: {val_results}")
-    if cfg.metric_log_dir and get_global_rank() == 0:
-        metric_log_path = Path(cfg.metric_log_dir) / "metrics.eval.jsonl"
+    if eval_args.metric_log_dir and get_global_rank() == 0:
+        metric_log_path = os.path.join(eval_args.metric_log_dir, "metrics.eval.jsonl")

         logger.info(f"Writing metric logs to {metric_log_path}")
         timestamp = {
             "created_at": datetime.utcnow().isoformat(),
         }
-        if cfg.global_step is not None:
-            timestamp["global_step"] = cfg.global_step
+        if eval_args.global_step is not None:
+            timestamp["global_step"] = eval_args.global_step
         print(
             json.dumps(timestamp | results["results"]),
-            file=open(metric_log_path, mode="a"),
+            file=fs.open(metric_log_path, mode="a"),
             flush=True,
         )

-        val_log_path = Path(cfg.metric_log_dir) / "metrics.validation.jsonl"
+        val_log_path = os.path.join(
+            eval_args.metric_log_dir, "metrics.validation.jsonl"
+        )
         if val_results is not None:
             print(
                 json.dumps(timestamp | val_results),
-                file=open(val_log_path, mode="a"),
+                file=fs.open(val_log_path, mode="a"),
                 flush=True,
             )

@@ -301,53 +251,8 @@ def launch_eval(cfg: EvalArgs):


 def main():
-    """
-    The command line interface here uses OmegaConf https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#from-command-line-arguments
-    This accepts arguments as a dot list
-    So if the dataclass looks like
-
-    @dataclass
-    class DummyArgs:
-        name: str
-        model: LMTransformerArgsgs
-
-    @dataclass
-    class LMTransformerArgsgs:
-        dim: int
-
-    Then you can pass model.dim=32 to change values in LMTransformerArgsgs
-    or just name=tictac for top level attributes.
-
-    The behavior here is as follows:
-    1. We instantiate EvalArgs with its default values
-    2. We override those default values with the ones in the provided config file
-    3. We override the result with the additional arguments provided through command line
-
-    For example, if the config is the following
-
-    model:
-        dim: 128
-        n_layers: 4
-
-    and you call eval.py with eval.py model.dim=64
-
-    Then the final TrainArgs will have
-
-    model:
-        dim: 64
-        n_layers: 4
-
-    Plus all the default values in EvalArgs dataclass.
-    """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.structured(EvalArgs())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_object(cfg)
-    launch_eval(cfg)
+    eval_args = parse_args(EvalArgs)
+    launch_eval(eval_args)


 if __name__ == "__main__":
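
With EvalArgs now defined in bytelatent.args and main() reduced to parse_args(EvalArgs) plus launch_eval(...), evals can also be kicked off programmatically. A minimal sketch, assuming a consolidated checkpoint already exists under ckpt_dir (both directories below are placeholders):

    from bytelatent.args import EvalArgs
    from bytelatent.eval import launch_eval

    # dump_dir and ckpt_dir are required fields on the new pydantic EvalArgs.
    eval_args = EvalArgs(
        dump_dir="/tmp/blt/evals/0000001000",
        ckpt_dir="/tmp/blt/checkpoints/0000001000",
    )
    launch_eval(eval_args)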
{apps/main → bytelatent}/generate.py RENAMED

@@ -1,20 +1,16 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.

+import os
 import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import List, Optional

 import torch
-from lingua.args import dataclass_from_dict
-from lingua.tokenizers.abstract_tokenizer import Tokenizer
-from lingua.tokenizers.build_tokenizer import build_tokenizer
 from omegaconf import OmegaConf
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import create_block_mask
 from tqdm import tqdm

+from bytelatent.args import PackedCausalTransformerGeneratorArgs, TrainArgs
 from bytelatent.base_transformer import (
     Attention,
     causal_mask,
@@ -23,7 +19,10 @@ from bytelatent.base_transformer import (
     lengths_to_start_ids,
 )
 from bytelatent.checkpoint import CONSOLIDATE_NAME
-from bytelatent.transformer import LMTransformer, LMTransformerArgs
+from bytelatent.data.file_util import get_fs
+from bytelatent.model.blt import ByteLatentTransformer
+from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
+from bytelatent.transformer import LMTransformer


 def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
@@ -62,7 +61,7 @@ def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None):
     return next_token.view(shape[:-1])


-def pack_prompts(prompts: List[int]):
+def pack_prompts(prompts: list[int]):
     res = []
     lengths = []
     for i, p in enumerate(prompts):
@@ -120,22 +119,6 @@ class KVCache(nn.Module):
         return self.k_cache, self.v_cache


-@dataclass
-class PackedCausalTransformerGeneratorArgs:
-    temperature: float = 0.0
-    top_p: Optional[float] = None
-    top_k: Optional[float] = None
-    max_gen_len: int = 512  # Maximum number of tokens to generate
-    max_tokens: int = 1024  # Maximum number of tokens that can go through the model
-    max_prompt_len: Optional[int] = None
-    until: List[str] = field(default_factory=list)
-    compile_prefilling: bool = False
-    reduce_generation_overhead: bool = False
-    show_progress: bool = False
-    dtype: Optional[str] = "bf16"
-    device: Optional[str] = "cuda"
-
-
 class PackedCausalTransformerGenerator:
     def __init__(
         self,
@@ -401,25 +384,29 @@ class PackedCausalTransformerGenerator:

 def load_consolidated_model_and_tokenizer(
     consolidated_path,
-    model_cls=LMTransformer,
-    model_args_cls=LMTransformerArgs,
 ):
-    ckpt_path = Path(consolidated_path)
-    config = ckpt_path / "params.json"
-    config = OmegaConf.load(config)
+    train_args_path = os.path.join(consolidated_path, "params.json")
+    fs = get_fs(train_args_path)
+    with fs.open(train_args_path) as f:
+        train_args = TrainArgs.model_validate_json(f.read())
+
+    if train_args.train_entropy_model:
+        model_args = train_args.entropy_model
+        model = LMTransformer(model_args)
+    else:
+        model_args = train_args.model
+        model = ByteLatentTransformer(model_args)

     param_dtype = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)[
-        config.distributed.model_dtype
+        train_args.distributed.model_dtype
     ]
-    model_args = dataclass_from_dict(model_args_cls, config.model, strict=False)
-    tokenizer = build_tokenizer(config.data.tokenizer.name, config.data.tokenizer.path)
-    model = model_cls(model_args)
-    st_dict = torch.load(ckpt_path / CONSOLIDATE_NAME, weights_only=True)
+    tokenizer = train_args.data.tokenizer_args.build()
+    st_dict = torch.load(consolidated_path / CONSOLIDATE_NAME, weights_only=True)
     model.load_state_dict(st_dict["model"])
     model = model.cuda().eval()
     for param in model.parameters():
         param.data = param.data.to(dtype=param_dtype)
-    return model, tokenizer, config
+    return model, tokenizer, train_args


 def main():
bytelatent/train.py CHANGED

@@ -10,7 +10,7 @@ from copy import deepcopy
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from timeit import default_timer as timer
-from typing import Any, Dict, Type, TypeVar
+from typing import Any, TypeVar

 import torch
 import torch.distributed
@@ -23,9 +23,13 @@ from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler

-from bytelatent.args import TrainArgs
+from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
-from bytelatent.data.data_types import DataLoaderState
+from bytelatent.data.iterators.multiprocess_iterator import (
+    MultiprocessIterator,
+    MultiprocessIteratorState,
+)
+from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
@@ -39,6 +43,7 @@ from bytelatent.distributed import (
     setup_env,
     setup_torch_distributed,
 )
+from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
 from bytelatent.logger import init_logger
 from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
 from bytelatent.model.blt import ByteLatentTransformer
@@ -70,36 +75,49 @@ def flatten_dict(d, parent_key="", sep="_"):
     return dict(items)


-def dataclass_from_dict(cls: Type[T], data: dict, strict: bool = True) -> T:
-    """
-    Converts a dictionary to a dataclass instance, recursively for nested structures.
-    """
-    base = OmegaConf.structured(cls())
-    OmegaConf.set_struct(base, strict)
-    override = OmegaConf.create(data)
-    return OmegaConf.to_object(OmegaConf.merge(base, override))
+def get_iterator_state_name(iterator_state):
+    if isinstance(iterator_state, MultiprocessIteratorState):
+        return "multiprocess"
+    elif isinstance(iterator_state, PackingIteratorState):
+        return "packing"
+    else:
+        raise ValueError(f"Unsupported iterator to get name from: {iterator_state}")


+# TODO: Make this pydantic based instead of data class based
+# TODO: Generalize this to any iterator state
 @dataclass
 class TrainState(Stateful):
     step: int  # Nb of steps taken by the optimizer
     acc_step: int  # Nb of accumulation steps done since last optimizer step
     scheduler: lr_scheduler.LambdaLR
-    data_loader_state: DataLoaderState
+    data_loader_state: MultiprocessIteratorState | PackingIteratorState
     scale: float = 1.0
+    data_loader_class: str | None = None

-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> dict[str, Any]:
         return {
             "step": self.step,
             "acc_step": self.acc_step,
-            "data_loader_state": self.data_loader_state.dict(),
+            "data_loader_state": self.data_loader_state.model_dump(),
+            "data_loader_class": get_iterator_state_name(self.data_loader_state),
             "scheduler": self.scheduler.state_dict(),
         }

     def load_state_dict(self, state_dict):
         self.step = state_dict["step"]
         self.acc_step = state_dict["acc_step"]
-        self.data_loader_state = DataLoaderState(**state_dict["data_loader_state"])
+        self.data_loader_class = state_dict["data_loader_class"]
+        if self.data_loader_class == "multiprocess":
+            self.data_loader_state = MultiprocessIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        elif self.data_loader_class == "packing":
+            self.data_loader_state = PackingIteratorState(
+                **state_dict["data_loader_state"]
+            )
+        else:
+            raise ValueError(f"invalid data loader class: {self.data_loader_class}")
         self.scheduler.load_state_dict(state_dict["scheduler"])


@@ -345,7 +363,10 @@ def train(args: TrainArgs):
     nwords_since_last_log = 0
     time_last_log = timer()
     gc.collect()
-    while train_state.step < args.steps:
+    saved = False
+    while train_state.step < args.steps and (
+        args.max_steps is None or train_state.step < args.max_steps
+    ):
         # We constrain train_state.acc_step to be in range 0 to args.grad_acc_steps - 1
         train_state.acc_step += 1
         train_state.acc_step = train_state.acc_step % args.grad_acc_steps
@@ -552,7 +573,6 @@ def train(args: TrainArgs):
                 f" pow: {gpu_mem_stats.power_draw/1000} W"
             )

-        saved = False
         if every_n_steps(
             train_state, args.checkpoint.dump.every, acc_step=0
         ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
@@ -567,18 +587,14 @@ def train(args: TrainArgs):
         if args.eval is not None and every_n_steps(
             train_state, args.checkpoint.eval.every, acc_step=0
         ):
-            from apps.main.eval import EVAL_FOLDER_NAME, EvalArgs, launch_eval
-
-            eval_args = dataclass_from_dict(EvalArgs, args.eval)
+            eval_args = args.eval

             eval_args.global_step = train_state.step
             eval_args.ckpt_dir = str(checkpoint.existing_saves[-1])
-            eval_args.dump_dir = str(
-                os.path.join(
-                    args.dump_dir,
-                    "evals",
-                    EVAL_FOLDER_NAME.format(train_state.step),
-                )
+            eval_args.dump_dir = os.path.join(
+                args.dump_dir,
+                "evals",
+                EVAL_FOLDER_NAME.format(train_state.step),
             )
             eval_args.metric_log_dir = args.dump_dir
             if args.async_eval_gpus is None:
@@ -619,6 +635,9 @@ def train(args: TrainArgs):
         args,
         device_mesh=world_mesh,
     )
+    if isinstance(data_loader, MultiprocessIterator):
+        logger.info("Closing MP iterator before exiting")
+        data_loader.shutdown()
     gc.collect()

@@ -661,15 +680,7 @@ def main():

     Plus all the default values in TrainArgs dataclass.
     """
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.create(TrainArgs().model_dump())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
-    train_args = TrainArgs.model_validate(cfg)
+    train_args = parse_args(TrainArgs)
     if train_args.debug_dynamo:
         import torch._dynamo

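
The TrainState change is the core of the checkpoint-reload fix: the dataloader state is stored as a plain dict plus a "data_loader_class" tag, and load_state_dict uses the tag to rebuild the matching pydantic state class. A small illustrative round-trip (not part of the diff; it assumes a fully constructed TrainState with a real scheduler and iterator state):

    from bytelatent.train import TrainState, get_iterator_state_name


    def checkpoint_roundtrip(train_state: TrainState) -> None:
        """Serialize and immediately restore the train state, as a checkpoint would."""
        payload = train_state.state_dict()
        # The tag records which pydantic state class to rebuild on load.
        assert payload["data_loader_class"] == get_iterator_state_name(
            train_state.data_loader_state
        )
        train_state.load_state_dict(payload)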