par-meta committed
Commit c79b1fd · unverified · 1 Parent(s): 7044771

Fix distributed all reduce grad norm (#40)

Summary:

With >1 GPU but only 1 node, all-reduces fail when inputs are not bf16. This change uses a modified copy of torch's grad-norm clipping, which casts gradients to bf16 before computing norms, to avoid those failures.
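
In rough terms, the patched helper differs from `torch.nn.utils.clip_grad_norm_` only in casting each gradient to bf16 before taking norms. A minimal sketch of that idea (the helper name `_bf16_grad_norm` here is illustrative; the actual implementation is `fixed_clip_grad_norm_` in the diff below):

```python
import torch


def _bf16_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
    # Cast each gradient to bf16 so every tensor entering the norm
    # computation (and any later reduction of the result) is bf16.
    grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
    if len(grads) == 0:
        return torch.tensor(0.0)
    # Same structure as torch's clip_grad_norm_: per-gradient norms,
    # then the norm of those norms.
    per_grad = [torch.linalg.vector_norm(g, norm_type) for g in grads]
    return torch.linalg.vector_norm(torch.stack(per_grad), norm_type)
```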

Test Plan:

- Run unit tests (a sketch of one possible check is shown below)
- Run single-GPU training: `python -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
- Run 1-node, multi-GPU training: `torchrun --nproc-per-node 8 -m bytelatent.train config=internal/configs/s3_debug.yaml eval=null checkpoint.dump.every=100`
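
For reference, a single-device sanity check along these lines (hypothetical; not necessarily one of the unit tests in this PR) would confirm the patched function returns the same norm as torch's, up to bf16 precision:

```python
import copy

import torch

from bytelatent.norms import fixed_clip_grad_norm_


def check_fixed_clip_matches_torch():
    torch.manual_seed(0)
    # Two identical models so each clipping function sees the same gradients.
    model_a = torch.nn.Linear(16, 16)
    model_b = copy.deepcopy(model_a)
    x = torch.randn(4, 16)
    model_a(x).sum().backward()
    model_b(x).sum().backward()
    norm_a = torch.nn.utils.clip_grad_norm_(model_a.parameters(), max_norm=1.0)
    norm_b = fixed_clip_grad_norm_(model_b.parameters(), max_norm=1.0)
    # The bf16 cast costs precision, so compare with a loose tolerance.
    torch.testing.assert_close(norm_a, norm_b.to(norm_a.dtype), rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    check_fixed_clip_matches_torch()
```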

Files changed (2)
  1. bytelatent/norms.py +100 -0
  2. bytelatent/train.py +32 -3
bytelatent/norms.py ADDED
@@ -0,0 +1,100 @@
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+ from torch import Tensor
+ from torch.utils._foreach_utils import (
+     _device_has_foreach_support,
+     _group_tensors_by_device_and_dtype,
+     _has_foreach_support,
+ )
+
+
+ @torch.no_grad()
+ def fixed_clip_grad_norm_(
+     parameters: torch.Tensor | list[torch.Tensor],
+     max_norm: float,
+     norm_type: float = 2.0,
+     error_if_nonfinite: bool = False,
+     foreach: Optional[bool] = None,
+ ) -> torch.Tensor:
+     r"""Clip the gradient norm of an iterable of parameters.
+
+     The norm is computed over the norms of the individual gradients of all parameters,
+     as if the norms of the individual gradients were concatenated into a single vector.
+     Gradients are modified in-place.
+
+     This is a modified copy of ``torch.nn.utils.clip_grad_norm_`` that casts
+     gradients to bf16 before computing norms, to avoid all-reduce failures on
+     non-bf16 inputs in single-node multi-GPU runs.
+
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have gradients normalized
+         max_norm (float): max norm of the gradients
+         norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+             infinity norm.
+         error_if_nonfinite (bool): if True, an error is thrown if the total
+             norm of the gradients from :attr:`parameters` is ``nan``,
+             ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+         foreach (bool): use the faster foreach-based implementation.
+             If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+             fall back to the slow implementation for other device types.
+             Default: ``None``
+
+     Returns:
+         Total norm of the parameter gradients (viewed as a single vector).
+     """
+     if isinstance(parameters, torch.Tensor):
+         parameters = [parameters]
+     # The only change from torch's implementation: cast gradients to bf16
+     # before the norm computation.
+     grads = [p.grad.to(torch.bfloat16) for p in parameters if p.grad is not None]
+     max_norm = float(max_norm)
+     norm_type = float(norm_type)
+     if len(grads) == 0:
+         return torch.tensor(0.0)
+     first_device = grads[0].device
+     grouped_grads: Dict[
+         Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+     ] = _group_tensors_by_device_and_dtype(
+         [grads]
+     )  # type: ignore[assignment]
+
+     norms: List[Tensor] = []
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+             foreach and _device_has_foreach_support(device)
+         ):
+             norms.extend(torch._foreach_norm(device_grads, norm_type))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
+
+     total_norm = torch.linalg.vector_norm(
+         torch.stack([norm.to(first_device) for norm in norms]), norm_type
+     )
+
+     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+         raise RuntimeError(
+             f"The total norm of order {norm_type} for gradients from "
+             "`parameters` is non-finite, so it cannot be clipped. To disable "
+             "this error and scale the gradients by the non-finite norm anyway, "
+             "set `error_if_nonfinite=False`"
+         )
+     clip_coef = max_norm / (total_norm + 1e-6)
+     # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+     # when the gradients do not reside in CPU memory.
+     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+     for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+         if (foreach is None and _has_foreach_support(device_grads, device)) or (
+             foreach and _device_has_foreach_support(device)
+         ):
+             torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+         elif foreach:
+             raise RuntimeError(
+                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+             )
+         else:
+             clip_coef_clamped_device = clip_coef_clamped.to(device)
+             for g in device_grads:
+                 g.mul_(clip_coef_clamped_device)
+
+     return total_norm
bytelatent/train.py CHANGED
@@ -47,6 +47,7 @@ from bytelatent.eval import EVAL_FOLDER_NAME, launch_eval
  from bytelatent.logger import init_logger
  from bytelatent.metrics import GPUMemoryMonitor, MetricLogger, get_num_params
  from bytelatent.model.blt import ByteLatentTransformer
+ from bytelatent.norms import fixed_clip_grad_norm_
  from bytelatent.optim import build_optimizer
  from bytelatent.probe import AutoProbeD
  from bytelatent.profiling import maybe_run_profiler
@@ -147,9 +148,26 @@ def validate_train_args(args: TrainArgs, output_size: int):
          * args.distributed.tp_size
          != get_world_size()
      ):
+         logging.info("Modifying TrainArgs distributed config")
          assert get_world_size() % args.distributed.dp_shard == 0
+         logging.info("World size: %s", get_world_size())
+         logging.info(
+             "Existing setting: train_args.distributed.dp_shard=%s",
+             args.distributed.dp_shard,
+         )
+         logging.info(
+             "Setting train_args.distributed.dp_replicate=%s, was dp_replicate=%s",
+             get_world_size() // args.distributed.dp_shard,
+             args.distributed.dp_replicate,
+         )
          args.distributed.dp_replicate = get_world_size() // args.distributed.dp_shard

+         logging.info(
+             "Changing dp_replicate from %s to %s, to account for tp_size=%s",
+             args.distributed.dp_replicate,
+             args.distributed.dp_replicate // args.distributed.tp_size,
+             args.distributed.tp_size,
+         )
          assert args.distributed.dp_replicate % args.distributed.tp_size == 0
          args.distributed.dp_replicate = (
              args.distributed.dp_replicate // args.distributed.tp_size
@@ -470,9 +488,20 @@ def train(args: TrainArgs):
          # For logging we undo that scaling
          loss = loss.detach() * args.grad_acc_steps

-         grad_norm = torch.nn.utils.clip_grad_norm_(
-             model.parameters(), max_norm=args.optim.clip, foreach=True
-         )
+         world_size = get_world_size()
+         if 1 < world_size <= 8:
+             # For some reason, there are errors in reduces due to
+             # not working for non-bf16 numbers. This function is a patched
+             # version that converts gradients to bf16 before computing norms.
+             # The error only happens in distributed training on one node,
+             # hence the guard
+             grad_norm = fixed_clip_grad_norm_(
+                 model.parameters(), max_norm=args.optim.clip, foreach=True
+             )
+         else:
+             grad_norm = torch.nn.utils.clip_grad_norm_(
+                 model.parameters(), max_norm=args.optim.clip, foreach=True
+             )

          grad_norm = (
              grad_norm.full_tensor() if isinstance(grad_norm, DTensor) else grad_norm