par-meta committed on
Commit 6ffeb66 · unverified · 1 Parent(s): caec8d2

Changes for training entropy model and correcting attention in local models (#25)


Summary:

- Refactor local model configs to be separate and clearer
- Add attention arguments and correct which attention is used in local models
- Prepare for adding an entropy-model training script
- Fix failing unit tests

Test Plan:

bytelatent/args.py CHANGED
@@ -30,6 +30,7 @@ from bytelatent.model.blt import ByteLatentTransformerArgs
 from bytelatent.optim import OptimArgs
 from bytelatent.profiling import ProfilerArgs
 from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+from bytelatent.transformer import LMTransformerArgs

 logger = logging.getLogger()

@@ -163,6 +164,8 @@ class TrainArgs(BaseModel):

     seed: int = 42

+    debug_dynamo: bool = False
+
     # Number of gradient accumulation steps
     # Total batch size is batch_size*grad_acc_steps
     grad_acc_steps: int = 1
@@ -176,6 +179,10 @@
     data: DataloaderArgs = DataloaderArgs()
     optim: OptimArgs = OptimArgs()
     model: ByteLatentTransformerArgs = ByteLatentTransformerArgs()
+    # This is only needed for training the entropy model
+    entropy_model: LMTransformerArgs | None = None
+    # Instead of training main model, train entropy model
+    train_entropy_model: bool = False
     distributed: DistributedArgs = DistributedArgs()
     env: EnvironmentArgs = EnvironmentArgs()
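To make the new fields concrete, a minimal sketch of switching a run to entropy-model training follows; the hyperparameter values are illustrative only and are not taken from any config in this repo (real runs load them from YAML).

```python
from bytelatent.args import TrainArgs
from bytelatent.transformer import LMTransformerArgs

# Illustrative values only; in practice these come from a YAML config file.
train_args = TrainArgs(
    train_entropy_model=True,
    entropy_model=LMTransformerArgs(
        dim=512,
        n_layers=8,
        n_heads=8,
        vocab_size=260,
        max_seqlen=1024,
        attn_impl="xformers",
        attn_bias_type="local_block_causal",
        sliding_window=512,
    ),
)
assert train_args.train_entropy_model and train_args.entropy_model is not None
```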
 
bytelatent/base_transformer.py CHANGED
@@ -4,7 +4,7 @@ from enum import Enum
 from typing import Optional, Tuple, Union

 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from torch import nn
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import (
@@ -15,6 +15,7 @@ from torch.nn.attention.flex_attention import (
 from xformers.ops import AttentionBias, fmha

 from bytelatent import probe
+from bytelatent.tokenizers.constants import EOS_ID

 if int(os.environ.get("BLT_ALLOW_MISSING_FLEX_ATTENTION", False)) == 0:
     flex_attention_comp = torch.compile(flex_attention)
@@ -30,13 +31,14 @@ class InitStdFactor(Enum):


 class BaseTransformerArgs(BaseModel):
+    model_config = ConfigDict(extra="forbid")
     dim: int = 512
     n_layers: int = 8
-    head_dim: Optional[int] = None
-    n_heads: Optional[int] = None
-    n_kv_heads: Optional[int] = None
+    head_dim: int | None = None
+    n_heads: int | None = None
+    n_kv_heads: int | None = None

-    ffn_dim_multiplier: Optional[float] = None
+    ffn_dim_multiplier: float | None = None

     multiple_of: int = 256

@@ -44,11 +46,16 @@ class BaseTransformerArgs(BaseModel):
     rope_theta: float = 10000.0

-    init_base_std: Optional[float] = None
+    init_base_std: float | None = None
     init_std_factor: InitStdFactor = InitStdFactor.DISABLED

     max_seqlen: int = 1024

+    attn_impl: str | None = "sdpa"
+    attn_bias_type: str | None = None
+    # Special token config
+    eos_id: int | None = EOS_ID
+

 def cross_entropy(pred, target, **kwargs):
     return F.nll_loss(
@@ -294,6 +301,18 @@ class RMSNorm(nn.Module):
         torch.nn.init.ones_(self.weight)  # type: ignore


+def _reshape_for_attn_bias(
+    attn_bias: AttentionBias | None,
+    *tensors: torch.Tensor,
+) -> list[torch.Tensor]:
+    to_transform = list(tensors)
+    if isinstance(attn_bias, fmha.attn_bias.BlockDiagonalCausalMask):
+        # could be `view` instead of reshape during training, but for inference
+        # have to reshape due to strides mismatch
+        to_transform = [t.reshape(1, -1, *t.shape[2:]) for t in to_transform]
+    return to_transform
+
+
 class Attention(nn.Module):
     def __init__(
         self,
@@ -371,9 +390,12 @@ class Attention(nn.Module):
             output = flex_attention_comp(xq, xk, xv, block_mask=mask)
             output = output.transpose(1, 2).contiguous()  # B H S D -> B S H D

-        elif attn_impl == "fmha":
+        elif attn_impl == "xformers":
             assert mask is None or isinstance(mask, AttentionBias)
+            query_shape = xq.shape
+            xq, xk, xv = _reshape_for_attn_bias(mask, xq, xk, xv)
             output = fmha.memory_efficient_attention(xq, xk, xv, attn_bias=mask)
+            output = output.view(query_shape)
             # This uses B S H D instead of B H S D of pytorch

         elif attn_impl == "sdpa":
@@ -522,14 +544,16 @@ class TransformerBlock(nn.Module):
         mask: Optional[Union[BlockMask, AttentionBias, str]] = None,
         attn_impl: str = "sdpa",
     ) -> torch.Tensor:
-        h = x + self.attention(
+        attn_out = self.attention(
             self.attention_norm(x),
             freq_cis,
             tok_idx=tok_idx,
             mask=mask,
             attn_impl=attn_impl,
         )
-        out = h + self.feed_forward(self.ffn_norm(h))
+        h = x + attn_out
+        h_norm = self.ffn_norm(h)
+        out = h + self.feed_forward(h_norm)
         return out

     def init_weights(self, init_std=None, factor=1.0):
@@ -545,6 +569,8 @@ class BaseTransformer(nn.Module):
         super().__init__()
         self.dim = args.dim
         self.init_base_std = args.init_base_std
+        self.attn_impl = args.attn_impl
+        self.attn_bias_type = args.attn_bias_type
         self.init_std_factor = InitStdFactor(args.init_std_factor)
         self.max_seqlen = args.max_seqlen
         self.rope_embeddings = RotaryEmbedding(
@@ -552,6 +578,7 @@
             head_dim=args.head_dim or args.dim // args.n_heads,
             max_seqlen=args.max_seqlen,
         )
+        self.eos_id = args.eos_id

         self.layers = nn.ModuleList()
         for _ in range(args.n_layers):
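The renamed "xformers" path now flattens batched projections whenever a block-diagonal bias is used. A minimal sketch of the shape handling (sizes are arbitrary):

```python
import torch
from xformers.ops import fmha

# BlockDiagonalCausalMask treats all sequences as one packed row, so the
# (B, S, H, D) projections are flattened to (1, B*S, H, D) before attention
# and restored to the original query shape afterwards.
B, S, H, D = 2, 4, 3, 16
xq = torch.randn(B, S, H, D)
mask = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=[S] * B)
xq_flat = xq.reshape(1, -1, *xq.shape[2:])
assert xq_flat.shape == (1, B * S, H, D)
assert xq_flat.view(xq.shape).shape == xq.shape
```
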
bytelatent/configs/debug.yaml CHANGED
@@ -15,7 +15,6 @@ optim:

 distributed:
   fsdp_type: full_shard
-  compile: true
   model_dtype: bf16
   matmul_allow_tf32: false
   selective_activation_checkpointing: false
@@ -58,13 +57,13 @@ model:
   recompute_attn: false
   custom_bwd: false
   layer_ckpt: "none"
-  efficient_attn: "sdpa"
   patch_only_encoder: false
   patch_only_decoder: false
   use_local_encoder_transformer: true
   init_use_gaussian: true
   init_use_depth: "current"
   attn_bias_type: "block_causal"
+  attn_impl: "xformers"
   alpha_depth: "disabled"
   max_length: 256
   local_attention_window_len: 512
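A quick way to check the updated config is to load it and confirm the renamed attention field; a sketch using OmegaConf, which train.py already relies on:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("bytelatent/configs/debug.yaml")
assert cfg.model.attn_impl == "xformers"
assert "efficient_attn" not in cfg.model      # removed in this change
assert "compile" not in cfg.distributed       # removed in this change
```
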
bytelatent/data/iterators/test_arrow_iterator.py CHANGED
@@ -27,6 +27,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     start_state = arrow_file.get_state()
@@ -55,6 +56,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=251,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = resumed_state.build()
     for example in arrow_file.create_iter():
@@ -74,6 +76,7 @@ def test_basic_arrow_file():
         dataset_files=[ARROW_TEST_DATA_1],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = rank_state.build()
     expected_ids = []
bytelatent/distributed.py CHANGED
@@ -11,7 +11,6 @@ import socket
 import subprocess
 import sys
 import tempfile
-from dataclasses import asdict, dataclass
 from functools import lru_cache, partial, reduce
 from itertools import chain
 from typing import List, Optional, Tuple, Union
bytelatent/entropy_model.py CHANGED
@@ -1,12 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import json
+import logging
 import os
-import re

 import torch

 from bytelatent.transformer import LMTransformer, LMTransformerArgs

+logger = logging.getLogger()
+

 def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
     with open(os.path.join(entropy_model_checkpoint_dir, "params.json")) as fr:
@@ -14,6 +16,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):

     torch.set_default_dtype(torch.bfloat16)
     model_params = reloaded["model"]
+    logger.warning(
+        "Update checkpoint to load attn and sliding window args from checkpoint"
+    )
     entropy_model = LMTransformer(
         LMTransformerArgs(
             dim=model_params["dim"],
@@ -22,6 +27,9 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cpu"):
             max_seqlen=model_params["max_length"],
             ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
             vocab_size=model_params["vocab_size"],
+            attn_bias_type="local_block_causal",
+            attn_impl="xformers",
+            sliding_window=512,
         )
     )
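A hedged usage sketch of the updated loader; the checkpoint paths and byte values are placeholders, and a CUDA build with xformers is assumed.

```python
import torch

from bytelatent.entropy_model import load_entropy_model

# Placeholder paths; the directory must contain params.json and the state dict file.
entropy_model = load_entropy_model(
    "checkpoints/entropy_model", "checkpoints/entropy_model/entropy_model.pth"
).cuda()

tokens = torch.randint(4, 250, (1, 128), device="cuda")
# The model now builds its local_block_causal xformers mask internally.
preds = entropy_model(tokens)
```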
 
bytelatent/model/blt.py CHANGED
@@ -15,8 +15,8 @@ from bytelatent.base_transformer import (
     TransformerBlock,
 )
 from bytelatent.data.patcher import Patcher, PatcherArgs
-from bytelatent.model.local_models import LocalDecoder, LocalEncoder
-from bytelatent.model.transformer import GlobalTransformer
+from bytelatent.model.latent_transformer import GlobalTransformer
+from bytelatent.model.local_models import LocalDecoder, LocalEncoder, LocalModelArgs
 from bytelatent.model.utils import downsample
 from bytelatent.tokenizers.constants import BOE_ID, BOS_ID, EOS_ID, OFFSET, PAD_ID

@@ -403,7 +403,6 @@ def patch_ids_from_lengths(patch_lengths, seq_len):


 class ByteLatentTransformerArgs(BaseTransformerArgs):
-    model_config = ConfigDict(extra="forbid")
     # Basic model configuration
     seed: int = 42
     vocab_size: int = -1
@@ -412,7 +411,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     n_heads: int = 8
     # TODO: What is the purpose of this parameter?
     weight_tying: bool = False
-    sliding_window: Optional[int] = None

     # Architecture and dimensions
     dim_token: int = 256
@@ -471,11 +469,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     recompute_attn: bool = True
     custom_bwd: bool = False
     layer_ckpt: str = "all"
-    efficient_attn: str | None = None
-
-    # Architecture options
-    patch_only_encoder: bool = False
-    patch_only_decoder: bool = False

     # Initialization and attention
     init_use_gaussian: bool = True
@@ -541,9 +534,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
     # Logging
     full_logging_n_layers: int = 4

-    # Special token config
-    eos_id: int | None = None
-
     @model_validator(mode="after")
     def check_hash_byte_sizes(self) -> Self:
         if (
@@ -558,22 +548,6 @@ class ByteLatentTransformerArgs(BaseTransformerArgs):
         return self


-class LocalEncoderArgs(ByteLatentTransformerArgs):
-    # Local encoder specific dimensions
-    n_heads_local_encoder: int = 8
-    dim_token_emb: int | None = None
-    dim_patch_emb: int | None = None
-
-    def __post_init__(self):
-        # Override base args with local encoder specific values
-        self.dim = self.dim_local_encoder
-        self.n_layers = self.n_layers_local_encoder
-        self.n_heads = self.n_heads_local_encoder
-        self.cross_attn_decoder = False
-        self.cross_attn_k = self.cross_attn_k if self.cross_attn_encoder else None
-        self.attn_bias_type = "local_block_causal"
-
-
 class GlobalTransformerArgs(ByteLatentTransformerArgs):
     # Global encoder specific dimensions
     dim_token_emb: int | None = None
@@ -625,20 +599,42 @@ def create_global_transformer(args: ByteLatentTransformerArgs) -> GlobalTransformer:


 def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:
-    # First deep copy the original args
-    # Replace with local encoder specific values
-    local_encoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_encoder,
-            n_layers=args.n_layers_local_encoder,
-            n_heads=args.n_heads_local_encoder,
-            dim_token_emb=get_encoder_dim_token_emb(args),
-            dim_patch_emb=get_encoder_dim_patch_emb(args),
-            cross_attn_decoder=False,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
-            attn_bias_type="local_block_causal",
-        ),
+    local_encoder_args = LocalModelArgs(
+        # Updated args
+        dim=args.dim_local_encoder,
+        n_layers=args.n_layers_local_encoder,
+        n_heads=args.n_heads_local_encoder,
+        dim_token_emb=get_encoder_dim_token_emb(args),
+        dim_patch_emb=get_encoder_dim_patch_emb(args),
+        cross_attn_encoder=args.cross_attn_encoder,
+        cross_attn_decoder=False,
+        cross_attn_k=args.cross_attn_k if args.cross_attn_encoder else None,
+        cross_attn_init_by_pooling=args.cross_attn_init_by_pooling,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )

     return LocalEncoder(local_encoder_args)
@@ -646,18 +642,41 @@ def create_local_encoder(args: ByteLatentTransformerArgs) -> LocalEncoder:

 def create_local_decoder(args: ByteLatentTransformerArgs) -> LocalDecoder:
     # First deep copy the original args
-    local_decoder_args = args.model_copy(
-        deep=True,
-        update=dict(
-            dim=args.dim_local_decoder,
-            n_layers=args.n_layers_local_decoder,
-            n_heads=args.n_heads_local_decoder,
-            cross_attn_encoder=False,
-            cross_attn_init_by_pooling=False,  # states are already defined
-            dim_token_emb=get_decoder_dim_token_emb(args),
-            dim_patch_emb=args.dim_global,
-            cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
-        ),
+    local_decoder_args = LocalModelArgs(
+        dim=args.dim_local_decoder,
+        n_layers=args.n_layers_local_decoder,
+        n_heads=args.n_heads_local_decoder,
+        dim_token_emb=get_decoder_dim_token_emb(args),
+        dim_patch_emb=args.dim_global,
+        cross_attn_encoder=False,
+        cross_attn_decoder=args.cross_attn_decoder,
+        cross_attn_init_by_pooling=False,  # states are already defined
+        cross_attn_k=args.cross_attn_k if args.cross_attn_decoder else None,
+        # Defaults
+        head_dim=args.head_dim,
+        max_seqlen=args.max_encoder_seq_length,
+        dropout=args.dropout,
+        vocab_size=args.vocab_size + args.pm_size,
+        norm_eps=args.norm_eps,
+        patch_size=args.patch_size,
+        sliding_window=args.local_attention_window_len,
+        use_rope=args.use_rope,
+        rope_theta=args.rope_theta,
+        init_base_std=args.init_base_std,
+        init_std_factor=args.init_std_factor,
+        n_kv_heads=args.n_kv_heads,
+        attn_impl=args.attn_impl,
+        attn_bias_type="local_block_causal",
+        multiple_of=args.multiple_of,
+        ffn_dim_multiplier=args.ffn_dim_multiplier,
+        patching_mode=args.patching_mode,
+        use_local_encoder_transformer=args.use_local_encoder_transformer,
+        downsampling_by_pooling=args.downsampling_by_pooling,
+        encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
+        cross_attn_all_layers_encoder=args.cross_attn_all_layers_encoder,
+        cross_attn_all_layers_decoder=args.cross_attn_all_layers_decoder,
+        cross_attn_nheads=args.cross_attn_nheads,
+        eos_id=args.eos_id,
     )

     return LocalDecoder(local_decoder_args)
@@ -763,7 +782,6 @@ class ByteLatentTransformer(nn.Module):

         # General configuration
         self.weight_tying = args.weight_tying
-        self.sliding_window = args.sliding_window
         self.patch_size = args.patch_size
         self.patching_mode = args.patching_mode
         self.boe_id, self.bos_id, self.pad_id, self.eos_id = (
bytelatent/model/{transformer.py → latent_transformer.py} RENAMED
@@ -11,6 +11,7 @@ from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
     BaseTransformer,
+    BaseTransformerArgs,
     RMSNorm,
     flex_attention_comp,
     repeat_kv,
@@ -142,11 +143,10 @@ class CrossAttention(nn.Module):


 class GlobalTransformer(BaseTransformer):
-    def __init__(self, args):
+    def __init__(self, args: BaseTransformerArgs):
         super().__init__(args)
         self.dropout = args.dropout
-        self.sliding_window = args.sliding_window
-        self.efficient_attn = args.efficient_attn
+        self.eos_id = args.eos_id

         self.token_embedding_projection = None
         if args.dim_token_emb is not None and args.dim_token_emb != self.dim:
@@ -169,14 +169,19 @@ class GlobalTransformer(BaseTransformer):
         and projection to the token space.
         """
         bs, seqlen = tokens.shape
-        attn_impl = self.efficient_attn

         h = embeds

         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                self.attn_bias_type,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )
         )

         if self.token_embedding_projection is not None and h.shape[-1] != self.dim:
@@ -184,7 +189,7 @@

         h = F.dropout(h, p=self.dropout, training=self.training)

-        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)
+        h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=self.attn_impl)
         return h, cache

     def init_weights(self, init_base_std: float):
bytelatent/model/local_models.py CHANGED
@@ -1,44 +1,75 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.

 import logging
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 import torch
 import torch.nn
 import torch.nn as nn
+from pydantic import BaseModel, ConfigDict
 from torch.nn import functional as F
 from torch.nn.attention.flex_attention import BlockMask
 from xformers.ops import AttentionBias

 from bytelatent.base_transformer import (
+    BaseTransformerArgs,
     InitStdFactor,
     RMSNorm,
     RotaryEmbedding,
     TransformerBlock,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask, downsample
 from bytelatent.tokenizers.blt_tokenizer import BOE_ID

 logger = logging.getLogger()


+class LocalModelArgs(BaseTransformerArgs):
+    model_config = ConfigDict(extra="forbid")
+    # Override defaults
+    attn_impl: str | None = "xformers"
+    attn_bias_type: str | None = "local_block_causal"
+
+    # Local encoder specific dimensions
+    dropout: float
+    vocab_size: int
+    patch_size: int
+    sliding_window: int | None
+    use_rope: bool
+    cross_attn_encoder: bool | None
+    cross_attn_decoder: bool | None
+    cross_attn_k: int | None
+    cross_attn_init_by_pooling: bool
+    patching_mode: str
+    use_local_encoder_transformer: bool
+    downsampling_by_pooling: str | None
+    encoder_hash_byte_group_size: Any | None = None
+    cross_attn_all_layers_encoder: bool = False
+    cross_attn_all_layers_decoder: bool = False
+    cross_attn_nheads: int | None
+
+    dim_token_emb: int
+    dim_patch_emb: int | None
+
+
 class LocalModelBase(nn.Module):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__()

         self.dim = args.dim
         self.dropout = args.dropout
-        self.vocab_size = args.vocab_size + args.pm_size
+        self.vocab_size = args.vocab_size
         self.patch_size = args.patch_size

-        self.efficient_attn = args.efficient_attn
+        self.attn_impl = args.attn_impl
         self.sliding_window = args.sliding_window
         self.use_rope = args.use_rope
         self.init_std_factor = args.init_std_factor
         self.cross_attn_encoder = getattr(args, "cross_attn_encoder", None)
         self.cross_attn_decoder = getattr(args, "cross_attn_decoder", None)
         self.cross_attn_k = getattr(args, "cross_attn_k", None)
+        self.eos_id = args.eos_id

         self.boe_id = BOE_ID

@@ -54,7 +85,7 @@ class LocalModelBase(nn.Module):
         self.rope = RotaryEmbedding(
             theta=args.rope_theta,
             head_dim=args.head_dim or args.dim // args.n_heads,
-            max_seqlen=getattr(args, "max_encoder_seq_length", args.max_length),
+            max_seqlen=args.max_seqlen,
         )
         self.pos_embeddings = None

@@ -66,21 +97,15 @@ class LocalModelBase(nn.Module):

         self.patch_embedding_projection = self._create_patch_projection(args)

-    def _should_create_patch_projection(self, args):
+    def _should_create_patch_projection(self, args: LocalModelArgs):
         dimension_mismatch = (
             getattr(args, "dim_patch_emb") and args.dim_patch_emb != self.dim
         )

         # Check cross attention conditions
         cross_attn_conditions = (
-            hasattr(args, "cross_attn_encoder")
-            and args.cross_attn_encoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        ) or (
-            hasattr(args, "cross_attn_decoder")
-            and args.cross_attn_decoder
-            and getattr(args, "cross_attn_init_by_pooling")
-        )
+            args.cross_attn_encoder and args.cross_attn_init_by_pooling
+        ) or (args.cross_attn_decoder and args.cross_attn_init_by_pooling)

         return dimension_mismatch or cross_attn_conditions

@@ -172,7 +197,7 @@ class LocalModelBase(nn.Module):


 class LocalEncoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)
         self.output_proj = (
             args.patching_mode in ["entropy", "probmax"]
@@ -180,7 +205,6 @@ class LocalEncoder(LocalModelBase):

         self.apply_transformer = args.use_local_encoder_transformer
         self.downsampling_by_pooling = args.downsampling_by_pooling
-        self.patch_only = args.patch_only_encoder
         self.expects_hash_embeddings = args.encoder_hash_byte_group_size is not None
         self.cross_attn_encoder = args.cross_attn_encoder
         self.cross_attn_all_layers_encoder = args.cross_attn_all_layers_encoder
@@ -224,7 +248,14 @@ class LocalEncoder(LocalModelBase):
         """ """
         bs, seqlen = tokens.shape
         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = self.apply_embedding(tokens, embeds)
         freqs_cis = self.rope(seqlen=seqlen) if self.use_rope else None
@@ -232,7 +263,7 @@ class LocalEncoder(LocalModelBase):
         h = F.dropout(h, p=self.dropout, training=self.training)

         for i, layer in enumerate(self.layers):
-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)
             # check if cross attention should be applied to either all layer or only the last layer
             if self.cross_attn_encoder and (
                 i == len(self.layers) - 1 or self.cross_attn_all_layers_encoder
@@ -273,12 +304,10 @@ class LocalEncoder(LocalModelBase):


 class LocalDecoder(LocalModelBase):
-    def __init__(self, args):
+    def __init__(self, args: LocalModelArgs):
         super().__init__(args)

         # Model configuration flags
-        self.patch_only = args.patch_only_decoder
-        self.expects_embeddings = args.share_encoder_decoder_emb
         self.cross_attn_decoder = args.cross_attn_decoder
         self.cross_attn_all_layers_decoder = args.cross_attn_all_layers_decoder
         self.cross_attn_init_by_pooling = args.cross_attn_init_by_pooling
@@ -317,7 +346,14 @@ class LocalDecoder(LocalModelBase):
         assert embeds is not None, "Embeddings must be provided"

         if mask is None:
-            mask = create_causal_mask(seqlen, self.efficient_attn, self.sliding_window)
+            mask = create_causal_mask(
+                seqlen,
+                self.attn_impl,
+                "local_block_causal",
+                sliding_window=self.sliding_window,
+                tokens=tokens,
+                eos_id=self.eos_id,
+            )

         h = embeds

@@ -347,7 +383,7 @@ class LocalDecoder(LocalModelBase):
                 )
                 h = h + h_cross

-            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.efficient_attn)
+            h = layer(h, mask=mask, freq_cis=freqs_cis, attn_impl=self.attn_impl)

         h_preds = self.norm(h)
         h_preds = F.dropout(h_preds, p=self.dropout, training=self.training)
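Because `LocalModelArgs` is an explicit pydantic model with `extra="forbid"`, stale fields such as the removed `efficient_attn` are now rejected at construction time. A sketch, with illustrative values (real instances are built by create_local_encoder/decoder):

```python
import pydantic

from bytelatent.model.local_models import LocalModelArgs

# Illustrative values only; base fields like dim/n_layers keep their defaults.
args = LocalModelArgs(
    dropout=0.0,
    vocab_size=260,
    patch_size=4,
    sliding_window=512,
    use_rope=True,
    cross_attn_encoder=False,
    cross_attn_decoder=False,
    cross_attn_k=None,
    cross_attn_init_by_pooling=False,
    cross_attn_nheads=None,
    patching_mode="entropy",
    use_local_encoder_transformer=True,
    downsampling_by_pooling="max",
    dim_token_emb=512,
    dim_patch_emb=None,
)
assert args.attn_impl == "xformers" and args.attn_bias_type == "local_block_causal"

try:
    LocalModelArgs(**args.model_dump(), efficient_attn="sdpa")  # removed field
except pydantic.ValidationError:
    print("unknown field rejected, as intended")
```
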
bytelatent/model/utils.py CHANGED
@@ -1,8 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+import logging
+import os
+
 import torch
 from torch.nn.attention.flex_attention import create_block_mask
 from xformers.ops import fmha

+logger = logging.getLogger()
+

 def patch_reduce(h, max_num_patches, reduction, patch_ids):
     """
@@ -97,15 +102,74 @@ def causal_mask(b, h, q_idx, kv_idx):
     return q_idx >= kv_idx


-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
+def tokens_to_seqlen(batch: torch.Tensor, eos_id: int):
+    """
+    0 0 0 1 0 0 0 1 0 0 0
+    0 1 0 0 0 1 0 0 0 0 0
+    -> 4 4 3 2 4 5
+    """
+    mask = batch == eos_id
+    mask[:, -1] = True  # virtual eos at the end of each row
+
+    # 0 0 0 1 0 0 0 1 0 0 X
+    # 0 1 0 0 0 1 0 0 0 0 X
+    row, col = torch.where(mask)
+
+    # row = 0, 0, 0, 1, 1, 1
+    # col = 3, 7, 10, 1, 5, 10
+    seqlens = (col[1:] - col[:-1]) + (row[1:] - row[:-1]) * mask.shape[1]
+    # seqlens = (4, 3, -9, 4, 5) + (0, 0, 11, 0, 0) = (4, 3, 2, 4, 5)
+    return [int(col[0].item() + 1)] + seqlens.tolist()
+
+
+def create_causal_mask(
+    seqlen,
+    attn_impl: str,
+    attn_bias_type: str | None,
+    *,
+    eos_id: int | None = None,
+    tokens: torch.Tensor | None = None,
+    sliding_window: int | None = None,
+):
+    if attn_impl == "xformers":
+        if attn_bias_type is None:
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "causal":
+            assert sliding_window is None
+            return fmha.attn_bias.LowerTriangularMask()
+        elif attn_bias_type == "block_causal":
+            assert sliding_window is None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            )
+        elif attn_bias_type == "local_block_causal":
+            assert sliding_window is not None
+            assert eos_id is not None
+            assert tokens is not None
+            return fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                q_seqlen=tokens_to_seqlen(tokens, eos_id)
+            ).make_local_attention(sliding_window)
+        else:
+            return fmha.attn_bias.LocalAttentionFromBottomRightMask(
+                window_left=sliding_window - 1, window_right=0
+            )
     elif attn_impl == "sdpa":
-        return "causal"
+        BLT_SUPPRESS_ATTN_ERROR = int(os.environ.get("BLT_SUPPRESS_ATTN_ERROR", 0))
+
+        if attn_bias_type == "causal":
+            return "causal"
+
+        if BLT_SUPPRESS_ATTN_ERROR == 1:
+            logging.warning(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. Allowing model to run since BLT_SUPPRESS_ATTN_ERROR=1"
+            )
+            return "causal"
+        else:
+            raise ValueError(
+                "SDPA attention being used, which doesn't have specialized attention implementations for block_causal and local_block_causal attention. To suppress this error and run the model anyway, set the environment variable BLT_SUPPRESS_ATTN_ERROR=1"
+            )
     elif attn_impl == "flex_attention":
         return create_block_mask(causal_mask, None, None, seqlen, seqlen)
     elif attn_impl == "fmha":
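The docstring of `tokens_to_seqlen` can be checked directly; a small sketch reproducing the documented example, where 1 plays the role of the EOS id:

```python
import torch

from bytelatent.model.utils import tokens_to_seqlen

batch = torch.tensor(
    [
        [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    ]
)
# Segments end at each EOS (or at the end of the row), row by row.
print(tokens_to_seqlen(batch, eos_id=1))  # [4, 4, 3, 2, 4, 5]
```
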
bytelatent/preprocess/fsspec_target.py ADDED
@@ -0,0 +1,38 @@
+import fsspec
+from luigi.target import FileSystem, FileSystemTarget
+
+
+class FSSpecFileSystem(FileSystem):
+    def __init__(self, fs: fsspec.AbstractFileSystem):
+        self.fs = fs
+
+    def exists(self, path):
+        return self.fs.exists()
+
+    def remove(self, path, recursive=True, skip_trash=True):
+        raise NotImplementedError()
+
+    def isdir(self, path):
+        return self.fs.isdir(path)
+
+    def listdir(self, path):
+        return self.fs.ls(path)
+
+
+class FSSpecTarget(FileSystemTarget):
+    def __init__(self, path, fs: fsspec.AbstractFileSystem | None = None):
+        self.path = path
+        if fs is None:
+            self.fsspec_fs = fsspec.filesystem("file")
+        else:
+            self.fsspec_fs = fs
+        self._fs = None
+
+    @property
+    def fs(self):
+        if self._fs is None:
+            self._fs = FSSpecFileSystem(self.fsspec_fs)
+        return self._fs
+
+    def open(self, mode):
+        return self.fs.open(self.path, mode=mode)
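A minimal sketch of using the new target directly, with illustrative paths; in practice it is meant to back Luigi task outputs in the preprocessing pipeline.

```python
import fsspec

from bytelatent.preprocess.fsspec_target import FSSpecTarget

# Wrap a local path; pass an s3fs/gcsfs filesystem instead for remote targets.
target = FSSpecTarget("/tmp/blt_preprocess", fs=fsspec.filesystem("file"))
print(target.fs.isdir("/tmp/blt_preprocess"))
print(target.fs.listdir("/tmp"))
```
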
bytelatent/test_blt.py CHANGED
@@ -23,9 +23,10 @@ from bytelatent.model.blt import (
     init_embeddings,
     patch_ids_from_lengths,
 )
-from bytelatent.model.transformer import CrossAttention
+from bytelatent.model.latent_transformer import CrossAttention
 from bytelatent.model.utils import create_causal_mask
 from bytelatent.optim import OptimArgs, build_optimizer
+from bytelatent.tokenizers.constants import EOS_ID
 from bytelatent.train import compute_loss


@@ -51,7 +52,7 @@ def batch_to_tensors_and_gpu(batch):


 def fake_batch():
-    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"))
+    batch_dict = torch.load(os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False)
     del batch_dict["x2"]
     del batch_dict["y2"]
     del batch_dict["src_names"]
@@ -98,18 +99,17 @@ def create_args(cross_attention=False):
         recompute_attn=False,
         custom_bwd=False,
         layer_ckpt="none",
-        efficient_attn="sdpa",
-        patch_only_encoder=False,
-        patch_only_decoder=False,
         use_local_encoder_transformer=True,
         init_use_gaussian=True,
         init_use_depth="current",
         attn_bias_type="block_causal",
+        attn_impl="xformers",
        alpha_depth="disabled",
         max_length=256,
         local_attention_window_len=512,
         max_seqlen=12288,
         downsampling_by_pooling="max",
+        eos_id=EOS_ID,
     )
     return transformer_args

@@ -341,10 +341,15 @@ class TestByteLatentTransformer:
         model = ByteLatentTransformer(args)
         assert model is not None

-    @pytest.mark.parametrize("attn_type", ["fmha", "sdpa"])
-    def test_blt_transformer_forward(self, attn_type):
+    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
+    def test_blt_transformer_forward(self, attn_impl):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn=attn_type))
+        if attn_impl == "sdpa":
+            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
+        else:
+            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
+
+        args = args.model_copy(update=dict(attn_impl=attn_impl))
         model = ByteLatentTransformer(args)
         model = model.cuda()
         batch = fake_batch()
@@ -393,7 +398,9 @@ class TestByteLatentTransformer:
             n_kv_heads=4,
             norm_eps=1e-6,
         ).to("cuda")
-        mask = create_causal_mask(x.shape[1], "flex_attention", sliding_window=None)
+        mask = create_causal_mask(
+            x.shape[1], "flex_attention", None, sliding_window=None
+        )
         output = cross_attention(x, kv, mask)
         assert output is not None
         assert output.shape == (2, 256, 512)
@@ -440,7 +447,7 @@ class TestByteLatentTransformer:

     def test_loss_backward(self):
         args = create_args()
-        args = args.model_copy(update=dict(efficient_attn="sdpa"))
+        args = args.model_copy(update=dict(attn_impl="xformers"))
         batch = fake_batch()
         model = ByteLatentTransformer(args)
         steps = 10
bytelatent/test_entropy_model.py CHANGED
@@ -24,6 +24,7 @@ def test_entropy_model():
         dataset_files=[ARROW_TEST_DATA],
         row_num=0,
         arrow_batch_size=100,
+        s3_profile=None,
     )
     arrow_file = initial_state.build()
     tokenizer_args = TokenizerArgs(
@@ -38,7 +39,7 @@ def test_entropy_model():
             BLT_DATA,
             "entropy_model.pth",
         ),
-    )
+    ).cuda()
     preprocess_iter = PreprocessIterator(
         arrow_file,
         tokenizer_args=tokenizer_args,
@@ -48,8 +49,10 @@ def test_entropy_model():
     for example in preprocess_iter.create_iter():
         tokens = torch.tensor(example.tokens).unsqueeze(0)
         expected_entropies = torch.tensor(example.entropies).unsqueeze(0)
-        preds = entropy_model(tokens)
+        preds = entropy_model(tokens.cuda())
         pred_entropies = entropy(preds)
         assert pred_entropies.shape == expected_entropies.shape
-        assert torch.allclose(pred_entropies, expected_entropies, rtol=1.0, atol=3.5)
+        assert torch.allclose(
+            pred_entropies.cpu(), expected_entropies, rtol=1.0, atol=3.5
+        )
         break
bytelatent/train.py CHANGED
@@ -644,6 +644,10 @@ def main():
     cfg = OmegaConf.merge(default_cfg, file_cfg, cli_args)
     cfg = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
     train_args = TrainArgs.model_validate(cfg)
+    if train_args.debug_dynamo:
+        import torch._dynamo
+
+        torch._dynamo.config.suppress_errors = True
     train(train_args)

 
bytelatent/transformer.py CHANGED
@@ -22,23 +22,7 @@ from bytelatent.base_transformer import (
     RMSNorm,
     cross_entropy,
 )
-
-
-def create_causal_mask(seqlen, attn_impl, sliding_window):
-    if sliding_window is not None and attn_impl == "xformers":
-        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
-            window_left=sliding_window - 1, window_right=0
-        )
-    elif attn_impl == "xformers":
-        return fmha.attn_bias.LowerTriangularMask()
-    elif attn_impl == "sdpa":
-        return "causal"
-    elif attn_impl == "flex_attention":
-        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
-    else:
-        raise NotImplementedError(
-            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
-        )
+from bytelatent.model.utils import create_causal_mask


 def attention_flops_per_token(n_layers, seq_len, dim, causal):
@@ -94,8 +78,10 @@ class LMTransformer(BaseTransformer):
         target: Optional[torch.Tensor] = None,
         tok_idx: Optional[torch.Tensor] = None,
         mask: Optional[Union[BlockMask, AttentionBias, torch.Tensor, str]] = None,
-        attn_impl: str = "sdpa",
+        attn_impl: str | None = None,
     ):
+        if attn_impl is None:
+            attn_impl = self.attn_impl
         bsz, seqlen = token_values.shape

         h = self.tok_embeddings(token_values)
@@ -103,7 +89,14 @@ class LMTransformer(BaseTransformer):
         mask = (
             mask
             if mask is not None
-            else create_causal_mask(seqlen, attn_impl, self.sliding_window)
+            else create_causal_mask(
+                seqlen,
+                attn_impl,
+                self.attn_bias_type,
+                sliding_window=self.sliding_window,
+                tokens=token_values,
+                eos_id=self.eos_id,
+            )
         )
         h = super().forward(h, tok_idx=tok_idx, mask=mask, attn_impl=attn_impl)

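With `attn_impl` now defaulting from the model args, forward calls no longer need to pass it explicitly. A hedged sketch, with illustrative hyperparameters and assuming the remaining LMTransformerArgs fields keep their defaults:

```python
import torch

from bytelatent.transformer import LMTransformer, LMTransformerArgs

# Illustrative values; "causal" + "sdpa" keeps the mask a plain string, so no
# tokens/eos bookkeeping is exercised here.
args = LMTransformerArgs(
    dim=256,
    n_layers=2,
    n_heads=4,
    vocab_size=260,
    max_seqlen=256,
    attn_impl="sdpa",
    attn_bias_type="causal",
)
lm = LMTransformer(args)
tokens = torch.randint(4, 259, (2, 64))
# mask=None and attn_impl=None: the model falls back to args.attn_impl and
# builds the causal mask from the token values itself.
out = lm(tokens)
```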