Srinivasan Iyer sviyer committed on
Commit
c110f6b
·
unverified ·
1 Parent(s): a5ceaaa

Add way to call consolidate (#80)

Browse files

* Add way to call consolidate

* black

* isort

---------

Co-authored-by: Srini Iyer <[email protected]>

Files changed (1) hide show
  1. bytelatent/checkpoint.py +18 -0
bytelatent/checkpoint.py CHANGED
@@ -12,6 +12,7 @@ import torch.distributed as dist
12
  import torch.distributed.checkpoint as dcp
13
  import torch.nn as nn
14
  import torch.optim.optimizer
 
15
  from pydantic import BaseModel, ConfigDict
16
  from torch.distributed._tensor import DeviceMesh
17
  from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
@@ -323,3 +324,20 @@ class CheckpointManager:
323
  dist.barrier()
324
 
325
  return cls(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  import torch.distributed.checkpoint as dcp
13
  import torch.nn as nn
14
  import torch.optim.optimizer
15
+ import typer
16
  from pydantic import BaseModel, ConfigDict
17
  from torch.distributed._tensor import DeviceMesh
18
  from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
 
324
  dist.barrier()
325
 
326
  return cls(args)
327
+
328
+
329
+ def main(
330
+ command: str,
331
+ model_checkpoint_dir: str,
332
+ ):
333
+ if command == "consolidate":
334
+ print(
335
+ f"Consolidating {model_checkpoint_dir}. Output will be in the {CONSOLIDATE_FOLDER} folder."
336
+ )
337
+ consolidate_checkpoints(fsspec.filesystem("file"), model_checkpoint_dir)
338
+ else:
339
+ raise ValueError("Invalid command")
340
+
341
+
342
+ if __name__ == "__main__":
343
+ typer.run(main)