Srinivasan Iyer sviyer committed on
Commit
138c2f3
·
unverified ·
1 Parent(s): 19a3f75

Init distributed when loading model (#94)

Browse files

Co-authored-by: Srini Iyer <[email protected]>

bytelatent/distributed.py CHANGED
@@ -301,7 +301,7 @@ def setup_torch_distributed(dist_args: DistributedArgs):
301
  - global_rank
302
  - world_size
303
  """
304
- mp.set_start_method(dist_args.spawn_method)
305
  with mp.Manager():
306
  pass
307
 
 
301
  - global_rank
302
  - world_size
303
  """
304
+ mp.set_start_method(dist_args.spawn_method, force=True)
305
  with mp.Manager():
306
  pass
307
 
bytelatent/generate.py CHANGED
@@ -25,7 +25,7 @@ from bytelatent.checkpoint import (
25
  )
26
  from bytelatent.config_parser import parse_args_to_pydantic_model
27
  from bytelatent.data.file_util import get_fs
28
- from bytelatent.distributed import get_global_rank
29
  from bytelatent.model.blt import ByteLatentTransformer
30
  from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
31
  from bytelatent.transformer import LMTransformer
@@ -390,7 +390,13 @@ class PackedCausalTransformerGenerator:
390
 
391
  def load_consolidated_model_and_tokenizer(
392
  consolidated_path,
 
393
  ):
 
 
 
 
 
394
  train_args_path = os.path.join(consolidated_path, "params.json")
395
  fs = get_fs(train_args_path)
396
  train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
 
25
  )
26
  from bytelatent.config_parser import parse_args_to_pydantic_model
27
  from bytelatent.data.file_util import get_fs
28
+ from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
29
  from bytelatent.model.blt import ByteLatentTransformer
30
  from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
31
  from bytelatent.transformer import LMTransformer
 
390
 
391
  def load_consolidated_model_and_tokenizer(
392
  consolidated_path,
393
+ init_distributed=False
394
  ):
395
+ if init_distributed:
396
+ distributed_args = DistributedArgs()
397
+ distributed_args.configure_world()
398
+ if not torch.distributed.is_initialized():
399
+ setup_torch_distributed(distributed_args)
400
  train_args_path = os.path.join(consolidated_path, "params.json")
401
  fs = get_fs(train_args_path)
402
  train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))