par-meta committed
Commit afedb16 · unverified · 1 Parent(s): 739dc71

Update checkpointing to use fsspec (#39)


Summary:

- Make the data/checkpoint code fsspec compatible (a sketch of the filesystem-resolution idea follows this list)
- S3 saves still will not work, because `torch.distributed.checkpoint.save` does not work with `fsspec` out of the box; that will be implemented in a follow-up PR
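
The core idea is that path-touching code goes through an `fsspec.AbstractFileSystem`, so the same checkpoint/data logic can target local disk or `s3://` URIs. Below is a minimal sketch of how such a filesystem can be resolved from a path; it is an illustration only, not the `get_fs` helper imported in the train.py diff, and the `s3_profile` plumbing is an assumption.

```python
import fsspec


def resolve_fs(path: str, s3_profile: str | None = None) -> fsspec.AbstractFileSystem:
    """Illustrative helper: pick an fsspec filesystem from the path's protocol."""
    if path.startswith("s3://"):
        # fsspec instantiates s3fs.S3FileSystem here; `profile` selects an AWS profile.
        return fsspec.filesystem("s3", profile=s3_profile)
    # Plain local filesystem for everything else.
    return fsspec.filesystem("file")
```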


Test Plan:

Run unit tests and the commands below

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100
```

These currently won't work due to the torch distributed save, but they should be tested at a later date:

```
python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```

```
torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100 dump_dir=s3://blt/scratch/checkpoint-test/
```
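
For reference, the checkpointing entry points in `bytelatent/checkpoint.py` now take an explicit `fsspec` filesystem as their first argument. A hedged usage sketch of the new call convention (not part of the test plan; the checkpoint directory below is a made-up placeholder):

```python
import fsspec

from bytelatent.checkpoint import consolidate_checkpoints

# Local filesystem; an s3fs filesystem would be passed the same way once
# torch.distributed.checkpoint saves support it (follow-up PR).
fs = fsspec.filesystem("file")

# "checkpoints/0000001000" is a placeholder path, not a real checkpoint from this run.
consolidated_dir = consolidate_checkpoints(fs, "checkpoints/0000001000")
print(f"Consolidated checkpoint written to: {consolidated_dir}")
```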

bytelatent/args.py CHANGED
@@ -294,6 +294,14 @@ class TrainArgs(BaseModel):
294
  def dump_to_yaml_file(
295
  self, path: str, log_config: bool = True, sort_keys: bool = True
296
  ):
 
 
 
 
 
 
 
 
297
  model_dict = self.model_dump(mode="json")
298
  yaml_str = yaml.dump(
299
  model_dict,
@@ -301,8 +309,4 @@ class TrainArgs(BaseModel):
301
  sort_keys=sort_keys,
302
  default_flow_style=False,
303
  )
304
- with open(path, "w") as f:
305
- if log_config:
306
- logger.info("Using the following config for this run:")
307
- logger.info(yaml_str)
308
- f.write(yaml_str)
 
294
  def dump_to_yaml_file(
295
  self, path: str, log_config: bool = True, sort_keys: bool = True
296
  ):
297
+ yaml_str = self.dump_to_yaml_str(sort_keys=sort_keys)
298
+ with open(path, "w") as f:
299
+ if log_config:
300
+ logger.info("Using the following config for this run:")
301
+ logger.info(yaml_str)
302
+ f.write(yaml_str)
303
+
304
+ def dump_to_yaml_str(self, sort_keys: bool = True):
305
  model_dict = self.model_dump(mode="json")
306
  yaml_str = yaml.dump(
307
  model_dict,
 
309
  sort_keys=sort_keys,
310
  default_flow_style=False,
311
  )
312
+ return yaml_str
 
 
 
 
bytelatent/checkpoint.py CHANGED
```diff
@@ -4,10 +4,9 @@ import json
 import logging
 import os
 import re
-from pathlib import Path
-from typing import List, Optional, Tuple
 
 import fsspec
+import s3fs
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
@@ -70,26 +69,29 @@ def consolidate_checkpoints(fs: fsspec.AbstractFileSystem, ckpt_dir: str):
 
     Returns the path to the consolidated checkpoint
     """
-    consolidate_path = Path(ckpt_dir) / CONSOLIDATE_FOLDER
-    if not (consolidate_path / CONSOLIDATE_NAME).exists():
-        consolidate_path.mkdir(exist_ok=True)
-        logger.info(f"Consolidating to: {str(consolidate_path)}")
-        dcp_to_torch_save(ckpt_dir, str(consolidate_path / CONSOLIDATE_NAME))
-        (consolidate_path / CONFIG_NAME).write_text(
-            (Path(ckpt_dir) / CONFIG_NAME).read_text()
+    consolidate_path = os.path.join(ckpt_dir, CONSOLIDATE_FOLDER)
+    consolidate_name = os.path.join(consolidate_path, CONSOLIDATE_NAME)
+    if not fs.exists(consolidate_name):
+        fs.mkdirs(consolidate_path, exist_ok=True)
+        logger.info(f"Consolidating to: {consolidate_path}")
+        dcp_to_torch_save(ckpt_dir, consolidate_name)
+        fs.write_text(
+            os.path.join(consolidate_path, CONFIG_NAME),
+            fs.read_text(os.path.join(ckpt_dir, CONFIG_NAME)),
         )
         logger.info("Consolidated !")
     return consolidate_path
 
 
 def load_from_checkpoint(
+    fs: fsspec.AbstractFileSystem,
     ckpt_dir: str,
     model: nn.Module,
-    optimizer: Optional[torch.optim.Optimizer] = None,
+    optimizer: torch.optim.Optimizer | None = None,
     model_key: str = "model",
     optim_key: str = "optim",
 ):
-    if not (Path(ckpt_dir) / ".metadata").exists():
+    if not fs.exists(os.path.join(ckpt_dir, ".metadata")):
         raise ValueError(
             f"Please convert the checkpoint distcp format using `torch.distributed.checkpoint.format_utils.torch_save_to_dcp` before loading it"
         )
@@ -115,19 +117,24 @@ class CheckpointManager:
         self.init_ckpt_path = args.init_ckpt_path
         self.continue_training_from_init = args.continue_training_from_init
 
-        assert self.fs.exists(
-            self.path
-        ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
+        if not isinstance(self.fs, s3fs.S3FileSystem):
+            # S3 does not have a concept of directories
+            assert self.fs.exists(
+                self.path
+            ), f"Path {self.path} does not exist and needs to be created before using CheckpointManager (use instantiate_and_make_dir)"
 
         self.existing_saves = self.get_existing_saves()
 
-    def get_existing_saves(self) -> List[Path]:
-        folders = [
-            p
-            for p in Path(self.path).iterdir()
-            if p.is_dir() and re.match(RE_FOLDER, p.name)
-        ]
-        folders.sort(key=lambda p: _get_key_step(p.name))
+    def get_existing_saves(self) -> list[str]:
+        if self.fs.exists(self.path) and self.fs.isdir(self.path):
+            folders = [
+                p
+                for p in self.fs.ls(self.path)
+                if self.fs.isdir(p) and re.match(RE_FOLDER, os.path.basename(p))
+            ]
+        else:
+            folders = []
+        folders.sort(key=lambda p: _get_key_step(os.path.basename(p)))
         return folders
 
     def clean_up(self):
@@ -136,8 +143,9 @@ class CheckpointManager:
         eval_folders = []
         other_folders = []
         for p in self.existing_saves:
-            is_dump = _get_key_step(p.name) % self.dump_every.every == 0
-            is_eval = _get_key_step(p.name) % self.eval_every.every == 0
+            assert isinstance(p, str), f"Base path type: {p}"
+            is_dump = _get_key_step(os.path.basename(p)) % self.dump_every.every == 0
+            is_eval = _get_key_step(os.path.basename(p)) % self.eval_every.every == 0
             if is_dump:
                 dump_folders.append(p)
             if is_eval:
@@ -161,40 +169,39 @@ class CheckpointManager:
 
         if dist.get_rank() == 0:
             for folder in folder_to_remove:
-                for file in folder.iterdir():
-                    if file.is_file():
-                        file.unlink()
-                    elif file.is_dir():
-                        assert file.name in [CONSOLIDATE_FOLDER]
-                        for f in file.iterdir():
-                            f.unlink()
-                        file.rmdir()
-                folder.rmdir()
+                for file in self.fs.ls(folder):
+                    if self.fs.isfile(file):
+                        self.fs.rm_file(file)
+                    elif self.fs.isdir(file):
+                        assert os.path.name(file) in [CONSOLIDATE_FOLDER]
+                        for f in self.fs.ls(file):
+                            self.fs.rm(f)
+                        self.fs.rmdir(file)
+                self.fs.rmdir(folder)
 
         dist.barrier()
 
         self.existing_saves = list(folder_to_keep)
-        self.existing_saves.sort(key=lambda p: _get_key_step(p.name))
+        self.existing_saves.sort(key=lambda p: _get_key_step(os.path.basename(p)))
 
-    def get_last_step_path(self, dp_rank: int = 0) -> Optional[Path]:
+    def get_last_step_path(self, dp_rank: int = 0) -> str | None:
         path = None
         for p in reversed(self.existing_saves):
-            if (p / TRAIN_STATE_NAME.format(dp_rank)).is_file():
+
+            if self.fs.isfile(os.path.join(p, TRAIN_STATE_NAME.format(dp_rank))):
                 path = p
                 break
         return path
 
-    def _create_folder(self, base_path: Path, folder_name: str) -> Path:
-        folder = base_path / folder_name
+    def _create_folder(self, base_path: str, folder_name: str) -> str:
+        folder = os.path.join(base_path, folder_name)
         if get_is_master():
-            folder.mkdir(parents=False, exist_ok=True)
+            self.fs.mkdirs(folder, exist_ok=True)
         if dist.is_initialized():
            dist.barrier()
        return folder
 
-    def _get_dp_tp_mesh(
-        self, device_mesh: Optional[DeviceMesh] = None
-    ) -> Tuple[int, int]:
+    def _get_dp_tp_mesh(self, device_mesh: DeviceMesh | None = None) -> tuple[int, int]:
         dp_rank = 0
         tp_rank = 0
         if device_mesh is not None:
@@ -222,14 +229,14 @@ class CheckpointManager:
         model,
         optimizer,
         train_state,
-        config,
-        device_mesh: Optional[DeviceMesh] = None,
+        config: BaseModel,
+        device_mesh: DeviceMesh | None = None,
     ) -> bool:
 
         # When creating directory check if only rank0 or is there other solution
-        path = Path(self.path)
+        path = self.path
         curr_save_dir = self._create_folder(path, FOLDER_NAME.format(train_state.step))
-        logger.info(f"Saving to: {str(curr_save_dir)}")
+        logger.info(f"Saving to: {curr_save_dir}")
 
         if dist.is_initialized():
             dist.barrier()
@@ -242,17 +249,19 @@ class CheckpointManager:
         if dist.is_initialized():
             dist.barrier()
 
+        print("config type", type(config))
         if get_is_master():
-            config.dump_to_yaml_file(curr_save_dir / CONFIG_NAME)
+            self.fs.write_text(
+                os.path.join(curr_save_dir, CONFIG_NAME), config.model_dump_json()
+            )
 
         # Add json dump here
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         if tp_rank == 0:
             train_state_name = TRAIN_STATE_NAME.format(dp_rank)
-            logger.info(
-                f"Saving train state to: {str(curr_save_dir / train_state_name)}"
-            )
-            with open(curr_save_dir / train_state_name, "w") as f:
+            train_state_full_path = os.path.join(curr_save_dir, train_state_name)
+            logger.info(f"Saving train state to: {train_state_full_path}")
+            with self.fs.open(train_state_full_path, "w") as f:
                 json.dump(train_state.state_dict(), f)
             logger.info("Train state saved !")
 
@@ -271,7 +280,7 @@ class CheckpointManager:
         optimizer,
         train_state,
         device_mesh: DeviceMesh,
-        path: Optional[Path] = None,
+        path: str | None = None,
     ):
         dp_rank, tp_rank = self._get_dp_tp_mesh(device_mesh)
         # Loading tries to load the provided path, if not available the last saved step and finally from the init path
@@ -284,12 +293,12 @@ class CheckpointManager:
         # Only load train state if it's provided, the files exist and we're not loading from init path
         train_state_name = TRAIN_STATE_NAME.format(dp_rank)
         logger.info("Reloading train state")
-        with open(path / train_state_name, "r") as f:
+        with self.fs.open(os.path.join(path, train_state_name), "r") as f:
            train_state_dict = json.load(f)
         train_state.load_state_dict(train_state_dict)
         logger.info("Train state reloaded")
 
-        logger.info(f"Loading from: {str(path)}")
+        logger.info(f"Loading from: {path}")
         state_dict = self.get_state_dict(
             model=model,
             optimizer=optimizer,
```

bytelatent/logger.py CHANGED
```diff
@@ -6,6 +6,8 @@ import sys
 import time
 from datetime import timedelta
 
+import fsspec
+
 from bytelatent.distributed import get_global_rank, get_is_slurm_job
 
 
@@ -92,6 +94,7 @@ def init_logger(
     *,
     name: str | None = None,
     level: str = "INFO",
+    fs: fsspec.AbstractFileSystem | None = None,
 ):
     """
     Setup logging.
@@ -121,7 +124,11 @@ def init_logger(
 
     if log_file is not None and get_global_rank() == 0:
         # build file handler
-        file_handler = logging.FileHandler(log_file, "a")
+        if fs is None:
+            file_handler = logging.FileHandler(log_file, "a")
+        else:
+            file_stream = fs.open(log_file, mode="a")
+            file_handler = logging.StreamHandler(file_stream)
         file_handler.setLevel(logging.NOTSET)
         file_handler.setFormatter(LogFormatter())
         # update logger
```

bytelatent/metrics.py CHANGED
```diff
@@ -8,6 +8,7 @@ from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Union
 
+import fsspec
 import torch
 import torch.nn as nn
 import wandb
@@ -53,14 +54,24 @@ class LoggingArgs(BaseModel):
 
 
 class MetricLogger:
-    def __init__(self, outdir: Path, args: Any | None = None):
+    def __init__(
+        self,
+        outdir: Path,
+        # args: TrainArgs
+        args: Any | None = None,
+        fs: fsspec.AbstractFileSystem | None = None,
+    ):
         self.outdir = outdir
         self.jsonl_writer = None
+        self.fs = fs
         self.args = args
 
     def open(self):
         if self.jsonl_writer is None:
-            self.jsonl_writer = open(self.outdir, "a")
+            if self.fs is None:
+                self.jsonl_writer = open(self.outdir, "a")
+            else:
+                self.jsonl_writer = self.fs.open(self.outdir, "a")
             if (
                 self.args is not None
                 and self.args.logging.wandb is not None
```

bytelatent/train.py CHANGED
```diff
@@ -8,7 +8,6 @@ import sys
 from contextlib import ExitStack
 from copy import deepcopy
 from dataclasses import asdict, dataclass
-from pathlib import Path
 from timeit import default_timer as timer
 from typing import Any, TypeVar
 
@@ -18,13 +17,13 @@ import torch.nn.functional
 import torch.nn.functional as F
 import wandb
 import xformers.profiler
-from omegaconf import OmegaConf
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.optim import lr_scheduler
 
 from bytelatent.args import TrainArgs, parse_args
 from bytelatent.checkpoint import CheckpointManager, load_from_checkpoint
+from bytelatent.data.file_util import get_fs
 from bytelatent.data.iterators.multiprocess_iterator import (
     MultiprocessIterator,
     MultiprocessIteratorState,
@@ -136,11 +135,12 @@ def validate_train_args(args: TrainArgs, output_size: int):
 
     if args.checkpoint.path is None:
         logger.info(f"Setting checkpoint path to {args.checkpoint.path}")
-        args.checkpoint.path = str(Path(args.dump_dir) / "checkpoints")
+        args.checkpoint.path = os.path.join(args.dump_dir, "checkpoints")
 
+    data_fs = get_fs(args.data.root_dir, s3_profile=args.data.s3_profile)
     for source in args.data.sources:
         data_path = os.path.join(args.data.root_dir, source)
-        assert os.path.exists(data_path), f"{data_path} doesn't exist"
+        assert data_fs.exists(data_path), f"{data_path} doesn't exist"
 
     if (
         args.distributed.dp_replicate
@@ -255,10 +255,15 @@ def train(args: TrainArgs):
         args,
         tokenizer.n_words,
     )
+    dump_fs = get_fs(args.dump_dir, s3_profile=args.checkpoint.s3_profile)
     if get_is_master():
-        os.makedirs(args.dump_dir, exist_ok=True)
-        args.dump_to_yaml_file(Path(args.dump_dir) / "config.yaml")
-        init_logger(Path(args.dump_dir) / "train.log")
+        dump_fs.mkdirs(args.dump_dir, exist_ok=True)
+        config_yaml_str = args.dump_to_yaml_str()
+        logging.info("TrainArgs: \n%s", config_yaml_str)
+        dump_fs.write_text(
+            os.path.join(args.dump_dir, "config.yaml"), config_yaml_str
+        )
+        init_logger(os.path.join(args.dump_dir, "train.log"), fs=dump_fs)
     init_signal_handler(set_preemption_flag)  # For handling preemption signals.
     setup_env(args.env)
     setup_torch_distributed(args.distributed)
@@ -313,8 +318,11 @@ def train(args: TrainArgs):
 
     if args.checkpoint.init_ckpt_path:
         logger.info(f"Loading initial model from {args.checkpoint.init_ckpt_path}")
+        ckpt_fs = get_fs(
+            args.checkpoint.init_ckpt_path, s3_profile=args.checkpoint.s3_profile
+        )
         load_from_checkpoint(
-            args.checkpoint.init_ckpt_path, model, model_key="model"
+            ckpt_fs, args.checkpoint.init_ckpt_path, model, model_key="model"
         )  # Put model_key="" if its directly the model checkpoint
         model.rope_embeddings.reset_parameters()  # For RoPe initialization since it's a buffer it might not be loaded
     else:
@@ -352,13 +360,14 @@ def train(args: TrainArgs):
         checkpoint.load(model, optimizer, train_state, world_mesh)
     # Either load from latest checkpoint or start from scratch
     if args.probe_freq is not None:
+        # TODO: Convert this to fsspec compatible
        if get_is_master():
-            os.makedirs(Path(args.dump_dir) / "probe", exist_ok=True)
+            os.makedirs(os.path.join(args.dump_dir, "probe"), exist_ok=True)
         torch.distributed.barrier()
         probe = AutoProbeD(
             model,
             (
-                Path(args.dump_dir) / "probe" / f"probe.{dp_rank}.jsonl"
+                os.path.join(args.dump_dir, "probe", f"probe.{dp_rank}.jsonl")
                 if (dp_rank % 128 == 0)
                 else None
             ),
@@ -370,7 +379,7 @@ def train(args: TrainArgs):
     # train loop
     model.train()
     metric_logger = context_stack.enter_context(
-        MetricLogger(Path(args.dump_dir) / "metrics.jsonl", args)
+        MetricLogger(os.path.join(args.dump_dir, "metrics.jsonl"), args, fs=dump_fs)
     )
     data_loader = train_state.data_loader_state.build()
     batch_iterator = data_loader.create_iter()
```