par-meta committed · Commit 82ab593 · unverified · Parent: 9f29e0d

Make it possible to specify multiple config files (#54)


Summary:

Make it possible to specify multiple config files.
Parsing the CLI is no longer a special case; it uses the same config inheritance mechanism as included files.

Test Plan:

Unit tests verify that configs are interpolated and merged in the right order.

Sample usage below loads the internal config, which references bytelatent/configs/entropy_model.yaml (a short sketch of the merge order follows the command). The precedence order, from lowest to highest, is:

- Default pydantic args
- Included configs, e.g. the file(s) passed via `config`
- CLI args

```
python -m bytelatent.print_config config=internal/configs/entropy_model.yaml eval=null

```
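For illustration, a minimal sketch of this precedence order in code. The `DemoArgs` class, the temporary YAML file, and its values are hypothetical and only serve to show how the three layers merge:

```
# Sketch of the merge order: pydantic defaults < included config file < CLI args.
import tempfile

from omegaconf import OmegaConf
from pydantic import BaseModel

from bytelatent.config_parser import parse_args_to_pydantic_model


class DemoArgs(BaseModel):  # hypothetical args class, for illustration only
    seed: int = 0
    name: str = "default"


with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write("seed: 7\nname: demo\n")  # values an included config might supply

# Simulate `... config=<file> seed=42` without going through sys.argv.
cli = OmegaConf.create({"config": f.name, "seed": 42})
args = parse_args_to_pydantic_model(DemoArgs, cli_args=cli)
assert args.name == "demo"  # the included file overrides the pydantic default
assert args.seed == 42      # the CLI argument overrides the included file
```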



bytelatent/args.py CHANGED
```
@@ -5,7 +5,6 @@ from typing import Any

 import numpy as np
 import yaml
-from omegaconf import OmegaConf
 from pydantic import BaseModel, ConfigDict

 from bytelatent.checkpoint import CheckpointArgs
@@ -38,19 +37,6 @@ def get_rng_state(seed: int, rank: int, world_size: int) -> dict[str, Any]:
     return np.random.default_rng((seed, rank, world_size)).bit_generator.state


-def parse_args(args_cls):
-    cli_args = OmegaConf.from_cli()
-    file_cfg = OmegaConf.load(cli_args.config)
-    # We remove 'config' attribute from config as the underlying DataClass does not have it
-    del cli_args.config
-
-    default_cfg = OmegaConf.create(args_cls().model_dump())
-    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
-    cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
-    pydantic_args = args_cls.model_validate(cfg)
-    return pydantic_args
-
-
 TRAIN_DATA_FILE_PATTERN = "*.chunk.*.jsonl"


```
bytelatent/config_parser.py ADDED
```
@@ -0,0 +1,73 @@
+import copy
+from typing import Type, TypeVar
+
+import omegaconf
+from omegaconf import DictConfig, OmegaConf
+from pydantic import BaseModel
+
+
+def parse_file_config(path: str) -> DictConfig:
+    file_cfg = OmegaConf.load(path)
+    if not isinstance(file_cfg, DictConfig):
+        raise ValueError(
+            f"File paths must parse to DictConfig, but it was: {type(file_cfg)}"
+        )
+    return file_cfg
+
+
+def recursively_parse_config(cfg: DictConfig) -> list[DictConfig]:
+    if "config" not in cfg:
+        return [cfg]
+
+    ordered_cfgs = []
+    cfg = copy.deepcopy(cfg)
+    config_arg = cfg["config"]
+    del cfg["config"]
+    ordered_cfgs.append(cfg)
+
+    if isinstance(config_arg, str):
+        file_cfg = parse_file_config(config_arg)
+        sub_configs = recursively_parse_config(file_cfg)
+        ordered_cfgs = sub_configs + ordered_cfgs
+    elif isinstance(config_arg, omegaconf.listconfig.ListConfig):
+        sub_configs = []
+        for c in config_arg:
+            if not isinstance(c, str):
+                raise ValueError(
+                    f'If "config" is specified, it must be either a string path or a list of string paths. It was config={config_arg}'
+                )
+            config_to_parse = parse_file_config(c)
+            sub_configs.extend(recursively_parse_config(config_to_parse))
+        ordered_cfgs = sub_configs + ordered_cfgs
+    else:
+        raise ValueError(
+            f'If "config" is specified, it must be either a string path or a list of string paths, it was config={config_arg}'
+        )
+    return ordered_cfgs
+
+
+def parse_args_with_default(
+    *, default_cfg: DictConfig | None = None, cli_args: DictConfig | None = None
+):
+    if cli_args is None:
+        cli_args = OmegaConf.from_cli()
+    assert isinstance(
+        cli_args, DictConfig
+    ), f"CLI Args must be a DictConfig, not {type(cli_args)}"
+    ordered_cfgs = recursively_parse_config(cli_args)
+    if default_cfg is not None:
+        ordered_cfgs.insert(0, default_cfg)
+    cfg = OmegaConf.merge(*ordered_cfgs)
+    return OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
+
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def parse_args_to_pydantic_model(
+    args_cls: Type[T], cli_args: DictConfig | None = None
+) -> T:
+    default_cfg = OmegaConf.create(args_cls().model_dump())
+    parsed_cfg = parse_args_with_default(default_cfg=default_cfg, cli_args=cli_args)
+    pydantic_args = args_cls.model_validate(parsed_cfg)
+    return pydantic_args
```
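Since `config` can also be a list, several files can be stacked. A small sketch of the ordering this produces; the temporary files, the `write_yaml` helper, and the values are hypothetical and only illustrate the precedence:

```
# Sketch of the list form: later entries in `config` override earlier ones,
# and remaining CLI keys override both.
import tempfile

from omegaconf import OmegaConf

from bytelatent.config_parser import parse_args_with_default, recursively_parse_config


def write_yaml(text: str) -> str:  # hypothetical helper for this sketch
    f = tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False)
    f.write(text)
    f.close()
    return f.name


base = write_yaml("a: 1\nb: 2\n")
override = write_yaml("b: 20\n")

cli = OmegaConf.create({"config": [base, override], "c": 3})
# Lowest to highest precedence: [base cfg, override cfg, remaining CLI keys]
ordered = recursively_parse_config(cli)
assert [OmegaConf.to_container(c) for c in ordered] == [
    {"a": 1, "b": 2},
    {"b": 20},
    {"c": 3},
]

merged = parse_args_with_default(cli_args=cli)
assert merged == {"a": 1, "b": 20, "c": 3}
```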
bytelatent/configs/debug.yaml CHANGED
```
@@ -56,13 +56,11 @@ model:
   recompute_attn: false
   custom_bwd: false
   layer_ckpt: "none"
-  patch_only_encoder: false
-  patch_only_decoder: false
   use_local_encoder_transformer: true
   init_use_gaussian: true
   init_use_depth: "current"
-  attn_bias_type: "block_causal"
   attn_impl: "xformers"
+  attn_bias_type: "block_causal"
   alpha_depth: "disabled"
   max_length: 256
   local_attention_window_len: 512
```
bytelatent/configs/entropy_model.yaml CHANGED
```
@@ -2,9 +2,10 @@
 # Evals can be activated by uncommenting its config
 # python -m launchers.stool config=apps/main/configs/debug.yaml nodes=8 account=fair_amaia_cw_codegen qos=lowest

-dump_dir: /tmp/
+dump_dir: /tmp/blt-entropy
 name: "debug"
 steps: 100_000
+max_steps: null
 probe_freq: null
 seed: 777
 optim:
@@ -35,7 +36,6 @@ entropy_model:
   attn_impl: "xformers"

 data:
-  s3_profile: blt
   root_dir: ???
   sources:
     dclm_baseline_1.0: 1.0
```
bytelatent/eval.py CHANGED
```
@@ -5,18 +5,15 @@ import logging
 import os
 from collections import defaultdict
 from datetime import datetime
-from pathlib import Path
-from typing import Any

 import torch
 from lm_eval import simple_evaluate
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
-from omegaconf import OmegaConf
-from pydantic import BaseModel, ConfigDict

-from bytelatent.args import EvalArgs, ValidationArgs, parse_args
+from bytelatent.args import EvalArgs, ValidationArgs
 from bytelatent.checkpoint import CONSOLIDATE_FOLDER, consolidate_checkpoints
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
 from bytelatent.distributed import (
     DistributedArgs,
@@ -29,7 +26,6 @@ from bytelatent.generate import (
     PackedCausalTransformerGenerator,
     load_consolidated_model_and_tokenizer,
 )
-from bytelatent.transformer import LMTransformer, LMTransformerArgs

 EVAL_FOLDER_NAME = "{:010d}"

```
bytelatent/print_config.py ADDED
```
@@ -0,0 +1,11 @@
+from bytelatent.args import TrainArgs
+from bytelatent.config_parser import parse_args_to_pydantic_model
+
+
+def main():
+    train_args = parse_args_to_pydantic_model(TrainArgs)
+    print(train_args.model_dump_json(indent=4))
+
+
+if __name__ == "__main__":
+    main()
```
bytelatent/test_config_parser.py ADDED
```
@@ -0,0 +1,180 @@
+import os
+
+import pytest
+from omegaconf import DictConfig, MissingMandatoryValue, OmegaConf
+from pydantic import BaseModel, ConfigDict
+
+from bytelatent.config_parser import (
+    parse_args_to_pydantic_model,
+    parse_file_config,
+    recursively_parse_config,
+)
+
+FIXTURE_DIR = "fixtures/test-cfgs"
+
+
+def test_parse_file_config():
+    with pytest.raises(ValueError):
+        cfg = parse_file_config(os.path.join(FIXTURE_DIR, "list.yaml"))
+        assert isinstance(cfg, DictConfig)
+
+
+def test_nop():
+    cfg = OmegaConf.create({"a": 1})
+    parsed_cfgs = recursively_parse_config(cfg)
+    assert len(parsed_cfgs) == 1
+    assert parsed_cfgs[0] == cfg
+
+
+def test_root():
+    cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "root.yaml")})
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 2
+    assert len(parsed_cfgs[1]) == 0
+    assert parsed_cfgs[0]["seed"] == -1
+    with pytest.raises(MissingMandatoryValue):
+        assert parsed_cfgs[0]["b"]["y"] is not None
+
+    # Test basic cli override
+    cli_cfg = OmegaConf.create(
+        {"config": os.path.join(FIXTURE_DIR, "root.yaml"), "seed": 42}
+    )
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert parsed_cfgs[1]["seed"] == 42
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["seed"] == 42
+
+
+def test_one_level_include():
+    cli_cfg = OmegaConf.create({"config": os.path.join(FIXTURE_DIR, "middle.yaml")})
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 3
+    assert parsed_cfgs[0]["seed"] == -1
+    assert parsed_cfgs[1]["b"]["y"] == 10
+    assert len(parsed_cfgs[2]) == 0
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["b"]["y"] == 10
+
+    cli_cfg = OmegaConf.create(
+        {"config": os.path.join(FIXTURE_DIR, "middle.yaml"), "b": {"y": 100}}
+    )
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 3
+    assert parsed_cfgs[0]["seed"] == -1
+    assert parsed_cfgs[1]["b"]["y"] == 10
+    assert parsed_cfgs[2]["b"]["y"] == 100
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["b"]["y"] == 100
+
+
+def test_two_level_include():
+    cli_cfg = OmegaConf.create(
+        {"config": os.path.join(FIXTURE_DIR, "top.yaml"), "p": 500, "b": {"z": -2}}
+    )
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 4
+    assert parsed_cfgs[0]["seed"] == -1
+    assert parsed_cfgs[1]["b"]["y"] == 10
+    assert parsed_cfgs[2]["hello"] == "world"
+    assert parsed_cfgs[3]["p"] == 500
+    assert parsed_cfgs[3]["b"]["z"] == -2
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["a"] == 1
+    assert cfg["seed"] == -1
+    assert cfg["b"]["x"] == 0
+    assert cfg["b"]["y"] == 10
+    assert cfg["b"]["z"] == -2
+    assert cfg["hello"] == "world"
+
+
+def test_multiple_includes():
+    cli_cfg = OmegaConf.create(
+        {
+            "config": [
+                os.path.join(FIXTURE_DIR, "top.yaml"),
+                os.path.join(FIXTURE_DIR, "override.yaml"),
+            ],
+            "p": 500,
+            "b": {"z": -2},
+        }
+    )
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 5
+    assert parsed_cfgs[0]["seed"] == -1
+    assert parsed_cfgs[1]["b"]["y"] == 10
+    assert parsed_cfgs[2]["hello"] == "world"
+    assert parsed_cfgs[3]["a"] == 100
+    assert parsed_cfgs[4]["p"] == 500
+    assert parsed_cfgs[4]["b"]["z"] == -2
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["a"] == 100
+    assert cfg["seed"] == -1
+    assert cfg["b"]["x"] == 0
+    assert cfg["b"]["y"] == 10
+    assert cfg["b"]["z"] == -2
+    assert cfg["hello"] == "world"
+
+    cli_cfg = OmegaConf.create(
+        {
+            "config": [
+                os.path.join(FIXTURE_DIR, "top.yaml"),
+                os.path.join(FIXTURE_DIR, "override.yaml"),
+            ],
+            "p": 500,
+            "b": {"z": -2},
+            "a": 1000,
+        }
+    )
+    parsed_cfgs = recursively_parse_config(cli_cfg)
+    assert len(parsed_cfgs) == 5
+    assert parsed_cfgs[0]["seed"] == -1
+    assert parsed_cfgs[1]["b"]["y"] == 10
+    assert parsed_cfgs[2]["hello"] == "world"
+    assert parsed_cfgs[3]["a"] == 100
+    assert parsed_cfgs[4]["p"] == 500
+    assert parsed_cfgs[4]["b"]["z"] == -2
+    cfg = OmegaConf.merge(*parsed_cfgs)
+    assert cfg["a"] == 1000
+    assert cfg["seed"] == -1
+    assert cfg["b"]["x"] == 0
+    assert cfg["b"]["y"] == 10
+    assert cfg["b"]["z"] == -2
+    assert cfg["hello"] == "world"
+
+
+class SubConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    x: int = -100
+    y: int = -100
+    z: int = -5
+
+
+class SampleConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    a: int = -100
+    seed: int = -100
+    b: SubConfig = SubConfig()
+    hello: str = ""
+    p: int = -100
+
+
+def test_pydantic_parse():
+    cli_cfg = OmegaConf.create(
+        {
+            "config": [
+                os.path.join(FIXTURE_DIR, "top.yaml"),
+                os.path.join(FIXTURE_DIR, "override.yaml"),
+            ],
+            "p": 500,
+            "a": 1000,
+        }
+    )
+    cfg = parse_args_to_pydantic_model(SampleConfig, cli_args=cli_cfg)
+    assert isinstance(cfg, SampleConfig)
+    assert cfg.a == 1000
+    assert cfg.p == 500
+    assert cfg.seed == -1
+    assert cfg.b.x == 0
+    assert cfg.b.y == 10
+    assert cfg.b.z == -5
+    assert cfg.hello == "world"
```
bytelatent/train.py CHANGED
```
@@ -23,8 +23,9 @@ from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler

-from bytelatent.args import TrainArgs, parse_args
+from bytelatent.args import TrainArgs
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.abstract_iterator import get_state_and_refresh
 from bytelatent.data.iterators.multiprocess_iterator import (
@@ -824,7 +825,7 @@ def main():

     Plus all the default values in TrainArgs dataclass.
     """
-    train_args = parse_args(TrainArgs)
+    train_args = parse_args_to_pydantic_model(TrainArgs)
     if train_args.debug_dynamo:
         import torch._dynamo

```
fixtures/test-cfgs/list.yaml ADDED
```
@@ -0,0 +1 @@
+[1, 2, 3]
```
fixtures/test-cfgs/middle.yaml ADDED
```
@@ -0,0 +1,3 @@
+config: fixtures/test-cfgs/root.yaml
+b:
+  y: 10
```
fixtures/test-cfgs/override.yaml ADDED
```
@@ -0,0 +1 @@
+a: 100
```
fixtures/test-cfgs/root.yaml ADDED
```
@@ -0,0 +1,6 @@
+seed: -1
+a: 1
+b:
+  x: 0
+  y: ???
+  z: ???
```
fixtures/test-cfgs/top.yaml ADDED
```
@@ -0,0 +1,3 @@
+config: fixtures/test-cfgs/middle.yaml
+
+hello: world
```