Spaces:
Running
on
Zero
Running
on
Zero
Fix distributed all reduce grad norm (#40)
Browse filesSummary:
With >1 GPU, but only 1 node, all reduces fail when inputs are not bf16. This uses a modified copy of torch's grad norm to avoid failures
Test Plan:
- Run unit tests:
- Run single gpu training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
- Run 1 node, multi-gpu training `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
- bytelatent/norms.py +100 -0
- bytelatent/train.py +32 -3
bytelatent/norms.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
from torch.utils._foreach_utils import (
|
6 |
+
_device_has_foreach_support,
|
7 |
+
_group_tensors_by_device_and_dtype,
|
8 |
+
_has_foreach_support,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
@torch.no_grad()
|
13 |
+
def fixed_clip_grad_norm_(
|
14 |
+
parameters: torch.Tensor | list[torch.Tensor],
|
15 |
+
max_norm: float,
|
16 |
+
norm_type: float = 2.0,
|
17 |
+
error_if_nonfinite: bool = False,
|
18 |
+
foreach: Optional[bool] = None,
|
19 |
+
) -> torch.Tensor:
|
20 |
+
r"""Clip the gradient norm of an iterable of parameters.
|
21 |
+
|
22 |
+
The norm is computed over the norms of the individual gradients of all parameters,
|
23 |
+
as if the norms of the individual gradients were concatenated into a single vector.
|
24 |
+
Gradients are modified in-place.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
|
28 |
+
single Tensor that will have gradients normalized
|
29 |
+
max_norm (float): max norm of the gradients
|
30 |
+
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
|
31 |
+
infinity norm.
|
32 |
+
error_if_nonfinite (bool): if True, an error is thrown if the total
|
33 |
+
norm of the gradients from :attr:`parameters` is ``nan``,
|
34 |
+
``inf``, or ``-inf``. Default: False (will switch to True in the future)
|
35 |
+
foreach (bool): use the faster foreach-based implementation.
|
36 |
+
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
|
37 |
+
fall back to the slow implementation for other device types.
|
38 |
+
Default: ``None``
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
Total norm of the parameter gradients (viewed as a single vector).
|
42 |
+
"""
|
43 |
+
if isinstance(parameters, torch.Tensor):
|
44 |
+
parameters = [parameters]
|
45 |
+
grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
|
46 |
+
max_norm = float(max_norm)
|
47 |
+
norm_type = float(norm_type)
|
48 |
+
if len(grads) == 0:
|
49 |
+
return torch.tensor(0.0)
|
50 |
+
first_device = grads[0].device
|
51 |
+
grouped_grads: Dict[
|
52 |
+
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
|
53 |
+
] = _group_tensors_by_device_and_dtype(
|
54 |
+
[grads]
|
55 |
+
) # type: ignore[assignment]
|
56 |
+
|
57 |
+
norms: List[Tensor] = []
|
58 |
+
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
59 |
+
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
60 |
+
foreach and _device_has_foreach_support(device)
|
61 |
+
):
|
62 |
+
norms.extend(torch._foreach_norm(device_grads, norm_type))
|
63 |
+
elif foreach:
|
64 |
+
raise RuntimeError(
|
65 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
66 |
+
)
|
67 |
+
else:
|
68 |
+
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
|
69 |
+
|
70 |
+
total_norm = torch.linalg.vector_norm(
|
71 |
+
torch.stack([norm.to(first_device) for norm in norms]), norm_type
|
72 |
+
)
|
73 |
+
|
74 |
+
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
75 |
+
raise RuntimeError(
|
76 |
+
f"The total norm of order {norm_type} for gradients from "
|
77 |
+
"`parameters` is non-finite, so it cannot be clipped. To disable "
|
78 |
+
"this error and scale the gradients by the non-finite norm anyway, "
|
79 |
+
"set `error_if_nonfinite=False`"
|
80 |
+
)
|
81 |
+
clip_coef = max_norm / (total_norm + 1e-6)
|
82 |
+
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
|
83 |
+
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
|
84 |
+
# when the gradients do not reside in CPU memory.
|
85 |
+
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
86 |
+
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
|
87 |
+
if (foreach is None and _has_foreach_support(device_grads, device)) or (
|
88 |
+
foreach and _device_has_foreach_support(device)
|
89 |
+
):
|
90 |
+
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
|
91 |
+
elif foreach:
|
92 |
+
raise RuntimeError(
|
93 |
+
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
clip_coef_clamped_device = clip_coef_clamped.to(device)
|
97 |
+
for g in device_grads:
|
98 |
+
g.mul_(clip_coef_clamped_device)
|
99 |
+
|
100 |
+
return total_norm
|
bytelatent/train.py
CHANGED
@@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
|
|
47 |
from bytelatent.logger import init_logger
|
48 |
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
|
49 |
from bytelatent.model.blt import ByteLatentTransformer
|
|
|
50 |
from bytelatent.optim import build_optimizer
|
51 |
from bytelatent.probe import AutoProbeD
|
52 |
from bytelatent.profiling import maybe_run_profiler
|
@@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int):
|
|
147 |
* args.distributed.tp_size
|
148 |
!= get_world_size()
|
149 |
):
|
|
|
150 |
assert get_world_size() % args.distributed.dp_shard == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
154 |
args.distributed.dp_replicate = (
|
155 |
args.distributed.dp_replicate // args.distributed.tp_size
|
@@ -470,9 +488,20 @@ def train(args: TrainArgs):
|
|
470 |
# For logging we undo that scaling
|
471 |
loss = loss.detach() * args.grad_acc_steps
|
472 |
|
473 |
-
|
474 |
-
|
475 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
477 |
grad_norm = (
|
478 |
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|
|
|
47 |
from bytelatent.logger import init_logger
|
48 |
from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
|
49 |
from bytelatent.model.blt import ByteLatentTransformer
|
50 |
+
from bytelatent.norms import fixed_clip_grad_norm_
|
51 |
from bytelatent.optim import build_optimizer
|
52 |
from bytelatent.probe import AutoProbeD
|
53 |
from bytelatent.profiling import maybe_run_profiler
|
|
|
148 |
* args.distributed.tp_size
|
149 |
!= get_world_size()
|
150 |
):
|
151 |
+
logging.info("Modifying TrainArgs distributed config")
|
152 |
assert get_world_size() % args.distributed.dp_shard == 0
|
153 |
+
logging.info("World size: %s", get_world_size())
|
154 |
+
logging.info(
|
155 |
+
"Existing setting: train_args.distributed.dp_shard=%s",
|
156 |
+
args.distributed.dp_shard,
|
157 |
+
)
|
158 |
+
logging.info(
|
159 |
+
"Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
|
160 |
+
get_world_size() // args.distributed.dp_shard,
|
161 |
+
args.distributed.dp_replicate,
|
162 |
+
)
|
163 |
args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard
|
164 |
|
165 |
+
logging.info(
|
166 |
+
"Changing dp_replicate from %s to %s, to account for tp_size=%s",
|
167 |
+
args.distributed.dp_replicate,
|
168 |
+
args.distributed.dp_replicate // args.distributed.tp_size,
|
169 |
+
args.distributed.tp_size,
|
170 |
+
)
|
171 |
assert args.distributed.dp_replicate % args.distributed.tp_size == 0
|
172 |
args.distributed.dp_replicate = (
|
173 |
args.distributed.dp_replicate // args.distributed.tp_size
|
|
|
488 |
# For logging we undo that scaling
|
489 |
loss = loss.detach() * args.grad_acc_steps
|
490 |
|
491 |
+
world_size = get_world_size()
|
492 |
+
if 1 < world_size <= 8:
|
493 |
+
# For some reason, there are errors in reduces due to
|
494 |
+
# not working for non-bf16 numbers. This function is a patched
|
495 |
+
# version that converts gradients to bf16 before computing norms.
|
496 |
+
# The error only happens in distributed training on one node,
|
497 |
+
# hence the guard
|
498 |
+
grad_norm = fixed_clip_grad_norm_(
|
499 |
+
model.parameters(), max_norm=args.optim.clip, foreach=True
|
500 |
+
)
|
501 |
+
else:
|
502 |
+
grad_norm = torch.nn.utils.clip_grad_norm_(
|
503 |
+
model.parameters(), max_norm=args.optim.clip, foreach=True
|
504 |
+
)
|
505 |
|
506 |
grad_norm = (
|
507 |
grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm
|