par-meta committed · unverified
Commit fe45f69 · 1 Parent(s): aebdc48

Add bpb and n_bytes to metric logging (#41)
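Bits-per-byte (bpb) normalizes the summed token-level loss (in nats) by ln(2) and by the number of UTF-8 bytes in the targets, which makes losses comparable across tokenizers with different vocabularies. A minimal, standalone sketch of the quantity this commit logs (the function name and values are illustrative, not from the repo):

```python
import math

def bits_per_byte(total_token_loss_nats: float, total_n_bytes: int) -> float:
    """Summed per-token NLL in nats -> bits, then normalized per target byte."""
    return total_token_loss_nats / math.log(2) / total_n_bytes

# e.g. ~1.39e6 nats of summed token loss over 1,000,000 target bytes ~= 2.0 bpb
print(bits_per_byte(1.39e6, 1_000_000))
```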

.gitignore CHANGED
@@ -167,3 +167,4 @@ figures/
 .DS_Store
 internal/
 jobs_parallel-copy/
+wandb/
bytelatent/distributed.py CHANGED
@@ -127,6 +127,16 @@ def dist_max(x: Union[int, float], mesh: DeviceMesh = None):
     return tensor


+def dist_sum(
+    x: Union[int, float], mesh: DeviceMesh = None, reduce_dtype: torch.dtype = None
+):
+    tensor = torch.tensor(x).cuda()
+    if reduce_dtype is not None:
+        tensor = tensor.to(reduce_dtype)
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=mesh.get_group() if mesh else None)
+    return tensor
+
+
 def dist_mean(x: Union[int, float], mesh: DeviceMesh = None):
     tensor = torch.tensor(x).cuda()
     dist.all_reduce(tensor, op=ReduceOp.AVG, group=mesh.get_group() if mesh else None)
@@ -236,7 +246,7 @@ def setup_env(env_args: EnvironmentArgs):
         logger.warning(f"WARNING: Setting {name} to {value}")


-def setup_torch_distributed(dist_args):
+def setup_torch_distributed(dist_args: DistributedArgs):
     """
     Handle single and multi-GPU / multi-node / SLURM jobs.
     Initialize the following variables:
@@ -388,14 +398,14 @@ def clean_env():


 def parallelize_model(
-    model,
+    model: torch.nn.Module,
     device_mesh,
     model_args,
     distributed_args: DistributedArgs,
     fsdp_grouping_plan: Optional[List[Tuple[str, bool]]] = None,
     tp_parallelize=None,
     no_recompute_ops=None,
-):
+) -> torch.nn.Module:
     if distributed_args.tp_size > 1:
         assert (
             distributed_args.fsdp_type == "full_shard"
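A usage sketch for the new `dist_sum` helper alongside the existing `dist_mean` (assumes a process group is already initialized, e.g. via torchrun, with one CUDA device per rank; the per-rank values are made up):

```python
import torch
from bytelatent.distributed import dist_mean, dist_sum

# Per-rank scalars accumulated since the last log interval (illustrative values).
n_bytes_this_rank = 4_194_304
loss_this_rank = 2.17

# SUM across ranks, optionally casting to bfloat16 before the all-reduce,
# matching how train.py calls it in this commit.
total_bytes = dist_sum(n_bytes_this_rank, reduce_dtype=torch.bfloat16).item()

# AVG across ranks (existing helper).
avg_loss = dist_mean(loss_this_rank).item()
```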
bytelatent/metrics.py CHANGED
@@ -49,7 +49,6 @@ class LoggingArgs(BaseModel):
     model_config = ConfigDict(extra="forbid")
     freq: int = 10  # Log every freq optimizer steps
     acc_freq: int | None = None  # Log every acc_freq gradient accumulation steps
-
     wandb: WandbArgs | None = None


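For reference, the fields above control the logging cadence; a minimal construction sketch of the pydantic model shown in this hunk (WandbArgs left unset):

```python
from bytelatent.metrics import LoggingArgs

# Log every 10 optimizer steps, no per-gradient-accumulation-step logging, no wandb.
logging_args = LoggingArgs(freq=10, acc_freq=None, wandb=None)

# Because model_config uses extra="forbid", unknown keys are rejected:
# LoggingArgs(freq=10, frequency=10)  # raises pydantic.ValidationError
```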
bytelatent/train.py CHANGED
@@ -3,6 +3,7 @@

 import gc
 import logging
+import math
 import os
 import sys
 from contextlib import ExitStack
@@ -11,6 +12,7 @@ from dataclasses import asdict, dataclass
 from timeit import default_timer as timer
 from typing import Any, TypeVar

+import numpy as np
 import torch
 import torch.distributed
 import torch.nn.functional
@@ -32,7 +34,9 @@ from bytelatent.data.iterators.packing_iterator import PackingIteratorState
 from bytelatent.distributed import (
     check_model_value_range,
     clean_env,
+    dist_mean,
     dist_mean_dict,
+    dist_sum,
     get_device_mesh,
     get_is_master,
     get_world_size,
@@ -392,6 +396,9 @@ def train(args: TrainArgs):
         time_last_log = timer()
         gc.collect()
         saved = False
+        step_losses: list[float] = []
+        step_tok_losses: list[float] = []
+        n_bytes: int = 0
         while train_state.step < args.steps and (
             args.max_steps is None or train_state.step < args.max_steps
         ):
@@ -413,6 +420,21 @@
             batch_patch_lengths = torch.from_numpy(batch.patch_lengths).cuda()
             mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()

+            if args.data.tokenizer_args.name in ["bytes", "blt"]:
+                n_bytes += batch_y.numel() if mask is None else mask.sum()
+            elif args.data.tokenizer_args.name in ["sp", "tiktoken"]:
+                for example in batch.y:
+                    target_tokens = tokenizer.decode(example.tolist(), cut_at_eos=False)
+                    n_bytes += (
+                        len(bytes(target_tokens, encoding="utf-8", errors="ignore"))
+                        + sum(example == tokenizer.eos_id)
+                        + sum(example == tokenizer.bos_id)
+                    )
+            else:
+                raise ValueError(
+                    f"Unexpected tokenizer to count n_bytes for: {args.data.tokenizer_args.name}"
+                )
+
             if (
                 not args.train_entropy_model
                 and args.model.encoder_enable_byte_ngrams
@@ -487,7 +509,7 @@
                 batch_x, patch_lengths=batch_patch_lengths, ngram_ids=ngram_ids
             )

-            loss, _ = compute_loss(pred, batch_y, mask, train_state.scale)
+            loss, tok_loss = compute_loss(pred, batch_y, mask, train_state.scale)

             # We scale loss with grad_acc_steps so the gradient is the same
             # regardless of grad_acc_steps
@@ -498,6 +520,10 @@
             # For logging we undo that scaling
             loss = loss.detach() * args.grad_acc_steps

+            # Undo loss scaling so downstream doesn't need to worry about it
+            step_losses.append((loss / train_state.scale).item())
+            step_tok_losses.append(tok_loss / train_state.scale)
+
             world_size = get_world_size()
             if 1 < world_size <= 8:
                 # For some reason, there are errors in reduces due to
@@ -568,50 +594,108 @@
                     * wps
                 )

-                metrics = flatten_dict(
-                    {
-                        "global_step": train_state.step,
-                        "acc_step": train_state.acc_step,
-                        "speed": {
-                            "wps": wps,
-                            "FLOPS": FLOPS,
-                            "curr_iter_time": curr_iter_time,
-                            "data_load_time": data_load_time,
-                        },
-                        "optim": {
-                            "grad_norm": grad_norm,
-                            "lr": curr_lr,
-                            "total_tokens": total_tokens,
-                        },
-                        "memory": gpu_mem_stats._asdict(),
-                    },
-                    sep="/",
-                )
-
-                to_sync = {}
-                to_sync["loss/out"] = loss.item()
-                metrics.update(dist_mean_dict(to_sync))
+                # Below, semantics are:
+                # per_gpu: Metrics on a given rank
+                # across_gpus: Metrics averaged/summed across all ranks
+                # step: Metric at a step
+                # interval: Metric averaged/summed across all steps since the last log interval.
+                #   Typically, this is 10
+                step_loss_per_gpu = loss.item()
+                step_loss_across_gpus = dist_mean(step_loss_per_gpu).item()
+                interval_loss_per_gpu = np.mean(step_losses).item()
+                interval_loss_across_gpus = dist_mean(interval_loss_per_gpu).item()
+
+                stacked_tok_loss = torch.cat(step_tok_losses, dim=0)
+                interval_total_tok_loss_per_gpu = stacked_tok_loss.sum().item()
+                interval_total_tok_loss_across_gpus = dist_sum(
+                    interval_total_tok_loss_per_gpu, reduce_dtype=torch.bfloat16
+                ).item()
+                interval_total_n_bytes_per_gpu = n_bytes
+                interval_total_n_bytes_across_gpus = dist_sum(
+                    n_bytes, reduce_dtype=torch.bfloat16
+                ).item()
+
+                interval_bpb_per_gpu = (
+                    interval_total_tok_loss_per_gpu
+                    / math.log(2)
+                    / interval_total_n_bytes_per_gpu
+                )
+                interval_bpb_across_gpus = (
+                    interval_total_tok_loss_across_gpus
+                    / math.log(2)
+                    / interval_total_n_bytes_across_gpus
+                )
+
+                metric_dict = {
+                    "global_step": train_state.step,
+                    "acc_step": train_state.acc_step,
+                    "speed": {
+                        "wps": wps,
+                        "FLOPS": FLOPS,
+                        "curr_iter_time": curr_iter_time,
+                        "data_load_time": data_load_time,
+                    },
+                    "optim": {
+                        "grad_norm": grad_norm,
+                        "lr": curr_lr,
+                        "total_tokens": total_tokens,
+                    },
+                    "memory": gpu_mem_stats._asdict(),
+                    "loss": {
+                        "step_per_gpu": step_loss_per_gpu,
+                        "step_across_gpu": step_loss_across_gpus,
+                        "interval_per_gpu": interval_loss_per_gpu,
+                        "interval_across_gpu": interval_loss_across_gpus,
+                    },
+                    "bpb": {
+                        "interval_per_gpu": interval_bpb_per_gpu,
+                        "interval_across_gpus": interval_bpb_across_gpus,
+                    },
+                    "n_bytes": {
+                        "interval_per_gpu": interval_total_n_bytes_per_gpu,
+                        "interval_across_gpus": interval_total_n_bytes_across_gpus,
+                    },
+                }
+
+                metrics = flatten_dict(
+                    metric_dict,
+                    sep="/",
+                )

                 if get_is_master():
                     metric_logger.log(metrics)

-                gpu_memory_monitor.reset_peak_stats()
-                nwords_since_last_log = 0
-                time_last_log = timer()
+                # Below semantics are:
+                # step=Metrics at a step
+                # interval=Metrics averaged across the logging interval
+                # local=On one rank
+                # global=Across all ranks
                 logger.info(
                     f"step: {train_state.step}"
                     f" acc: {train_state.acc_step}"
-                    f" loss: {round(loss.item(),4):>7}"
+                    f" loss_gpu: {round(interval_loss_per_gpu, 4):>7}"
+                    f" loss_avg: {round(interval_loss_across_gpus, 4):>7}"
+                    f" bpb_gpu: {interval_bpb_per_gpu:.3f}"
+                    f" bpb_avg: {interval_bpb_across_gpus:.3f}"
                     f" grad: {grad_norm:.2e}"
                     f" flops: {FLOPS:.2e}"
                     f" wps: {wps:.2e}"
                     f" iter: {curr_iter_time:>7}"
                     f" data: {data_load_time:>5}"
                     f" lr: {curr_lr:.2e}"
+                    f" n_bytes_gpu: {int(interval_total_n_bytes_per_gpu)}"
+                    f" n_bytes_sum: {int(interval_total_n_bytes_across_gpus)}"
                     f" mem: {gpu_mem_stats.max_active_pct:.0f}%"
                     f" pow: {gpu_mem_stats.power_draw/1000} W"
                 )

+                n_bytes = 0
+                step_losses = []
+                step_tok_losses = []
+                gpu_memory_monitor.reset_peak_stats()
+                nwords_since_last_log = 0
+                time_last_log = timer()
+
                 if every_n_steps(
                     train_state, args.checkpoint.dump.every, acc_step=0
                 ) or every_n_steps(train_state, args.checkpoint.eval.every, acc_step=0):
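The nested `metric_dict` above is flattened with `sep="/"` before being logged, so the new metrics appear under keys like `bpb/interval_per_gpu` and `n_bytes/interval_across_gpus`. A small illustration with a stand-in helper (not necessarily the repo's `flatten_dict` implementation) and made-up values:

```python
def flatten(d: dict, sep: str = "/", prefix: str = "") -> dict:
    # Recursively join nested keys with `sep`, mirroring what the logged names look like.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{sep}{k}" if prefix else str(k)
        if isinstance(v, dict):
            out.update(flatten(v, sep=sep, prefix=key))
        else:
            out[key] = v
    return out

metric_dict = {
    "global_step": 100,
    "bpb": {"interval_per_gpu": 0.92, "interval_across_gpus": 0.93},
    "n_bytes": {"interval_per_gpu": 4_194_304, "interval_across_gpus": 33_554_432},
}
print(flatten(metric_dict))
# {'global_step': 100, 'bpb/interval_per_gpu': 0.92, 'bpb/interval_across_gpus': 0.93,
#  'n_bytes/interval_per_gpu': 4194304, 'n_bytes/interval_across_gpus': 33554432}
```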