Update modernberg_model.py
Browse files — modernberg_model.py: +1 −13
modernberg_model.py
CHANGED
@@ -411,8 +411,6 @@ class GriffinRecurrentblock(nn.Module):
|
|
411 |
input_states: torch.Tensor,
|
412 |
position_ids: torch.Tensor,
|
413 |
attention_mask: torch.Tensor,
|
414 |
-
cache_position: torch.Tensor,
|
415 |
-
use_cache: bool = True,
|
416 |
**kwargs
|
417 |
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
418 |
_, seq_len, _ = input_states.shape
|
@@ -423,17 +421,7 @@ class GriffinRecurrentblock(nn.Module):
|
|
423 |
x_branch = self.linear_x(input_states)
|
424 |
x_branch = x_branch.transpose(1, 2)
|
425 |
|
426 |
-
|
427 |
-
if cache_position.shape[0] != 1: # prefill
|
428 |
-
self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
|
429 |
-
x_branch = self.conv_1d(x_branch)[..., :seq_len]
|
430 |
-
else: # decoding
|
431 |
-
conv_state = torch.cat((self.conv1d_state, x_branch), -1)
|
432 |
-
x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias
|
433 |
-
x_branch = x_branch.unsqueeze(-1)
|
434 |
-
self.conv1d_state = conv_state[:, :, 1:]
|
435 |
-
else:
|
436 |
-
x_branch = self.conv_1d(x_branch)[..., :seq_len]
|
437 |
|
438 |
x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
|
439 |
|
|
|
411 |
input_states: torch.Tensor,
|
412 |
position_ids: torch.Tensor,
|
413 |
attention_mask: torch.Tensor,
|
|
|
|
|
414 |
**kwargs
|
415 |
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
416 |
_, seq_len, _ = input_states.shape
|
|
|
421 |
x_branch = self.linear_x(input_states)
|
422 |
x_branch = x_branch.transpose(1, 2)
|
423 |
|
424 |
+
x_branch = self.conv_1d(x_branch)[..., :seq_len]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
|
426 |
x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
|
427 |
|