Update modeling_skywork_lm2.py
modeling_skywork_lm2.py CHANGED (+2 -2)
@@ -600,7 +600,7 @@ class SkyworkLM2FlashAttention2(SkyworkLM2Attention):
         )


-
+LM2_ATTENTION_CLASSES = {
     'eager': SkyworkLM2Attention,
     'flash_attention_2': SkyworkLM2FlashAttention2,
 }
@@ -612,7 +612,7 @@ class SkyworkLM2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size

-        self.attention =
+        self.attention = LM2_ATTENTION_CLASSES[config.attn_implementation](config=config)

         self.feed_forward = SkyworkLM2MLP(config)
         self.attention_norm = SkyworkLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
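The two added lines complete a standard dispatch pattern: a module-level dict maps each supported attn_implementation name to an attention class, and the decoder layer instantiates whichever class the config names. Below is a minimal runnable sketch of that pattern, assuming the class names from the diff; the stub class bodies and the SimpleNamespace config are illustrative stand-ins, not the real Skywork modules.

# Minimal sketch of the attention-class dispatch this patch completes.
# The stub classes below are illustrative stand-ins, not the real modules.
from types import SimpleNamespace

import torch.nn as nn


class SkyworkLM2Attention(nn.Module):
    """Stand-in for the eager (vanilla PyTorch) attention implementation."""

    def __init__(self, config):
        super().__init__()
        self.config = config


class SkyworkLM2FlashAttention2(SkyworkLM2Attention):
    """Stand-in for the FlashAttention-2 implementation."""


# Module-level registry: implementation name -> attention class.
LM2_ATTENTION_CLASSES = {
    'eager': SkyworkLM2Attention,
    'flash_attention_2': SkyworkLM2FlashAttention2,
}


class SkyworkLM2DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        # The fixed line: look the class up by name, then instantiate it.
        self.attention = LM2_ATTENTION_CLASSES[config.attn_implementation](config=config)


config = SimpleNamespace(hidden_size=4096, attn_implementation='flash_attention_2')
layer = SkyworkLM2DecoderLayer(config)
print(type(layer.attention).__name__)  # -> SkyworkLM2FlashAttention2

Keeping the registry at module level means new backends can be added by registering one more entry, and an unsupported config value fails fast with a KeyError at layer construction rather than silently falling back.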