jiangbop commited on
Commit
48da830
·
verified ·
1 Parent(s): a6a9fd8

Update modeling_skywork_lm2.py

Browse files
Files changed (1) hide show
  1. modeling_skywork_lm2.py +2 -2
modeling_skywork_lm2.py CHANGED
@@ -600,7 +600,7 @@ class SkyworkLM2FlashAttention2(SkyworkLM2Attention):
600
  )
601
 
602
 
603
- INTERNLM2_ATTENTION_CLASSES = {
604
  'eager': SkyworkLM2Attention,
605
  'flash_attention_2': SkyworkLM2FlashAttention2,
606
  }
@@ -612,7 +612,7 @@ class SkyworkLM2DecoderLayer(nn.Module):
612
  super().__init__()
613
  self.hidden_size = config.hidden_size
614
 
615
- self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
616
 
617
  self.feed_forward = SkyworkLM2MLP(config)
618
  self.attention_norm = SkyworkLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
600
  )
601
 
602
 
603
+ LM2_ATTENTION_CLASSES = {
604
  'eager': SkyworkLM2Attention,
605
  'flash_attention_2': SkyworkLM2FlashAttention2,
606
  }
 
612
  super().__init__()
613
  self.hidden_size = config.hidden_size
614
 
615
+ self.attention = LM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
616
 
617
  self.feed_forward = SkyworkLM2MLP(config)
618
  self.attention_norm = SkyworkLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)