Init distributed when loading model (#94)
Co-authored-by: Srini Iyer <[email protected]>
- bytelatent/distributed.py +1 -1
- bytelatent/generate.py +7 -1
bytelatent/distributed.py CHANGED
@@ -301,7 +301,7 @@ def setup_torch_distributed(dist_args: DistributedArgs):
     - global_rank
     - world_size
     """
-    mp.set_start_method(dist_args.spawn_method)
+    mp.set_start_method(dist_args.spawn_method, force=True)
     with mp.Manager():
         pass
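For context on why force=True is needed here: the standard library's multiprocessing.set_start_method raises RuntimeError ("context has already been set") if the start method was already chosen in the current process, so a repeated call to setup_torch_distributed would crash without it. A minimal standalone sketch of that stdlib behavior, independent of this repo:

import multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("spawn")
    try:
        # A second call without force raises
        # RuntimeError: context has already been set
        mp.set_start_method("spawn")
    except RuntimeError as e:
        print(f"second call failed: {e}")
    # force=True makes the call idempotent, which is what the
    # change above relies on when setup may run more than once.
    mp.set_start_method("spawn", force=True)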
bytelatent/generate.py CHANGED
@@ -25,7 +25,7 @@ from bytelatent.checkpoint import (
 )
 from bytelatent.config_parser import parse_args_to_pydantic_model
 from bytelatent.data.file_util import get_fs
-from bytelatent.distributed import get_global_rank
+from bytelatent.distributed import get_global_rank, setup_torch_distributed, DistributedArgs
 from bytelatent.model.blt import ByteLatentTransformer
 from bytelatent.tokenizers.abstract_tokenizer import Tokenizer
 from bytelatent.transformer import LMTransformer
@@ -390,7 +390,13 @@ class PackedCausalTransformerGenerator:

 def load_consolidated_model_and_tokenizer(
     consolidated_path,
+    init_distributed=False
 ):
+    if init_distributed:
+        distributed_args = DistributedArgs()
+        distributed_args.configure_world()
+        if not torch.distributed.is_initialized():
+            setup_torch_distributed(distributed_args)
     train_args_path = os.path.join(consolidated_path, "params.json")
     fs = get_fs(train_args_path)
     train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
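With the new flag, a standalone script can ask the loader to bring up torch.distributed itself instead of relying on the training entrypoint having done so. A hypothetical usage sketch; the checkpoint path is a placeholder, and the unpacked return values are assumed from the function name rather than shown in this diff:

from bytelatent.generate import load_consolidated_model_and_tokenizer

# init_distributed defaults to False, so existing callers are unaffected.
# With True, the loader builds a default DistributedArgs, configures the
# world, and calls setup_torch_distributed only if torch.distributed is
# not already initialized, making it safe under both launch styles.
model, tokenizer, train_args = load_consolidated_model_and_tokenizer(
    "checkpoints/blt/consolidated",  # placeholder path
    init_distributed=True,
)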