luca-peric commited on
Commit
a37fec7
·
1 Parent(s): 661d10b

local block causal when cuda avail

Browse files
Files changed (1) hide show
  1. bytelatent/entropy_model.py +1 -1
bytelatent/entropy_model.py CHANGED
@@ -27,7 +27,7 @@ def load_entropy_model(entropy_model_checkpoint_dir, state_dict_path, device="cp
27
  max_seqlen=model_params["max_seqlen"],
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
- attn_bias_type="causal",
31
  attn_impl="xformers" if torch.cuda.is_available() else "sdpa",
32
  sliding_window=512,
33
  )
 
27
  max_seqlen=model_params["max_seqlen"],
28
  ffn_dim_multiplier=model_params["ffn_dim_multiplier"],
29
  vocab_size=model_params["vocab_size"],
30
+ attn_bias_type="local_block_causal" if torch.cuda.is_available() else "causal",
31
  attn_impl="xformers" if torch.cuda.is_available() else "sdpa",
32
  sliding_window=512,
33
  )