yushihu commited on
Commit
ea5c483
·
verified ·
1 Parent(s): 2d21c51

Update ensemble_qwen.py

Browse files
Files changed (1) hide show
  1. ensemble_qwen.py +2 -3
ensemble_qwen.py CHANGED
@@ -24,10 +24,9 @@ class EnsembleConfig(Qwen3Config):
24
 
25
  class EnsembleForCausalLM(PreTrainedModel, GenerationMixin):
26
  config_class = EnsembleConfig
27
- _supports_flash_attn_2 = True # <── NEW
28
- _supports_sdpa = True # optional, lets users pick "sdpa"
29
  main_input_name = "input_ids"
30
- _tied_weights_keys = ["model_a.lm_head.weight", "model_b.lm_head.weight"]
31
  _tp_plan = {"model_a.lm_head": "colwise_rep", "model_b.lm_head": "colwise_rep"}
32
 
33
  def __init__(self, config: EnsembleConfig):
 
24
 
25
  class EnsembleForCausalLM(PreTrainedModel, GenerationMixin):
26
  config_class = EnsembleConfig
27
+ _supports_flash_attn_2 = True
28
+ _supports_sdpa = True
29
  main_input_name = "input_ids"
 
30
  _tp_plan = {"model_a.lm_head": "colwise_rep", "model_b.lm_head": "colwise_rep"}
31
 
32
  def __init__(self, config: EnsembleConfig):