Allow return attention weights #1
opened by orby-yanan

Files changed:
- blocks.py +2 -2
- modeling_mpt.py +1 -1
blocks.py (CHANGED)

@@ -31,9 +31,9 @@ class MPTBlock(nn.Module):
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
         self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
 
-    def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
+    def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
         a = self.norm_1(x)
-        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
+        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=needs_weights)
         x = x + self.resid_attn_dropout(b)
         m = self.norm_2(x)
         n = self.ffn(m)
modeling_mpt.py (CHANGED)

@@ -199,7 +199,7 @@ class MPTModel(MPTPreTrainedModel):
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+            (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, needs_weights=output_attentions)
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
             if output_attentions: