Update ensemble_qwen.py
Browse files- ensemble_qwen.py +2 -3
ensemble_qwen.py
CHANGED
@@ -24,10 +24,9 @@ class EnsembleConfig(Qwen3Config):
|
|
24 |
|
25 |
class EnsembleForCausalLM(PreTrainedModel, GenerationMixin):
|
26 |
config_class = EnsembleConfig
|
27 |
-
_supports_flash_attn_2 = True
|
28 |
-
_supports_sdpa = True
|
29 |
main_input_name = "input_ids"
|
30 |
-
_tied_weights_keys = ["model_a.lm_head.weight", "model_b.lm_head.weight"]
|
31 |
_tp_plan = {"model_a.lm_head": "colwise_rep", "model_b.lm_head": "colwise_rep"}
|
32 |
|
33 |
def __init__(self, config: EnsembleConfig):
|
|
|
24 |
|
25 |
class EnsembleForCausalLM(PreTrainedModel, GenerationMixin):
|
26 |
config_class = EnsembleConfig
|
27 |
+
_supports_flash_attn_2 = True
|
28 |
+
_supports_sdpa = True
|
29 |
main_input_name = "input_ids"
|
|
|
30 |
_tp_plan = {"model_a.lm_head": "colwise_rep", "model_b.lm_head": "colwise_rep"}
|
31 |
|
32 |
def __init__(self, config: EnsembleConfig):
|