Fizzarolli committed
Commit ef095ca · verified · Parent(s): 8f10a7a

Update modernberg_model.py

Files changed (1):
  1. modernberg_model.py +1 -13
modernberg_model.py CHANGED
@@ -411,8 +411,6 @@ class GriffinRecurrentblock(nn.Module):
         input_states: torch.Tensor,
         position_ids: torch.Tensor,
         attention_mask: torch.Tensor,
-        cache_position: torch.Tensor,
-        use_cache: bool = True,
         **kwargs
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         _, seq_len, _ = input_states.shape
@@ -423,17 +421,7 @@ class GriffinRecurrentblock(nn.Module):
         x_branch = self.linear_x(input_states)
         x_branch = x_branch.transpose(1, 2)
 
-        if use_cache:
-            if cache_position.shape[0] != 1:  # prefill
-                self.conv1d_state = nn.functional.pad(x_branch, (self.conv1d_width - x_branch.shape[-1] - 1, 0))
-                x_branch = self.conv_1d(x_branch)[..., :seq_len]
-            else:  # decoding
-                conv_state = torch.cat((self.conv1d_state, x_branch), -1)
-                x_branch = torch.sum(conv_state * self.conv_1d.weight[:, 0, :], dim=-1) + self.conv_1d.bias
-                x_branch = x_branch.unsqueeze(-1)
-                self.conv1d_state = conv_state[:, :, 1:]
-        else:
-            x_branch = self.conv_1d(x_branch)[..., :seq_len]
+        x_branch = self.conv_1d(x_branch)[..., :seq_len]
 
         x_branch = self.rg_lru(x_branch.transpose(1, 2), position_ids)
 
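In short: the Griffin recurrent block previously threaded cache_position and use_cache through forward and kept a conv1d_state buffer so that, during one-token-at-a-time decoding, the temporal convolution could be updated incrementally. For a model that only ever runs full-sequence forward passes (presumably the case here, since there is no autoregressive decoding in an encoder-style model), only the full-convolution branch was ever exercised, so the commit deletes the cache machinery and keeps the single line self.conv_1d(x_branch)[..., :seq_len].

Below is a minimal, self-contained sketch of why this is a pure simplification for full-sequence inputs: the causal convolution the commit keeps produces the same outputs as the removed per-step cached path replayed over the whole sequence. Everything in the sketch is assumed for illustration, not taken from the repository; the sizes are arbitrary, and the depthwise layout (groups=channels) and left padding of width - 1 are inferred from the removed code's conv_1d.weight[:, 0, :] indexing and its [..., :seq_len] slice.

    # Standalone sketch, not the repository's code: shapes and hyperparameters
    # are assumptions chosen to mirror the diff above.
    import torch
    import torch.nn as nn

    batch, channels, width, seq_len = 2, 8, 4, 16  # illustrative sizes

    # Depthwise causal conv; padding=width-1 means the first seq_len outputs
    # are the purely causal ones, matching conv_1d(x)[..., :seq_len].
    conv_1d = nn.Conv1d(channels, channels, kernel_size=width,
                        groups=channels, padding=width - 1)
    x = torch.randn(batch, channels, seq_len)

    # Path kept by this commit: one full convolution over the sequence.
    full = conv_1d(x)[..., :seq_len]

    # Removed decoding path, replayed token by token: a sliding state holding
    # the previous width-1 inputs, as the deleted conv1d_state buffer did.
    state = torch.zeros(batch, channels, width - 1)
    steps = []
    for t in range(seq_len):
        window = torch.cat((state, x[..., t:t + 1]), dim=-1)   # [B, C, width]
        out = torch.sum(window * conv_1d.weight[:, 0, :], dim=-1) + conv_1d.bias
        steps.append(out.unsqueeze(-1))
        state = window[..., 1:]                                # slide the window
    stepped = torch.cat(steps, dim=-1)

    print(torch.allclose(full, stepped, atol=1e-5))  # should print True

If that check holds, the block's full-sequence behaviour is unchanged by the commit; what is lost is only the per-token incremental update, which matters solely for autoregressive decoding.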