HongyuanTao committed on
Commit 6790e0a · verified · 1 Parent(s): 9598d88

Update modeling_mmMamba.py

Files changed (1)
  1. modeling_mmMamba.py +93 -7
modeling_mmMamba.py CHANGED
@@ -24,22 +24,20 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from einops import rearrange
 from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
-                                           CausalLMOutputWithPast,
-                                           SequenceClassifierOutputWithPast)
+                                           CausalLMOutputWithPast)
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (add_start_docstrings,
                                 add_start_docstrings_to_model_forward, logging,
                                 replace_return_docstrings)
-from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
-import copy
+from fused_norm_gate import FusedRMSNormSwishGate
+
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from transformers.cache_utils import Cache
-import time
+
 
 try:
     from transformers.generation.streamers import BaseStreamer
 
@@ -130,6 +128,94 @@ class mmMambaRMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)
 
+
+# Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->mmMamba
+class mmMambaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+        self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->mmMamba
+class mmMambaLinearScalingRotaryEmbedding(mmMambaRotaryEmbedding):
+    """mmMambaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+        self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+
+# Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->mmMamba
+class mmMambaDynamicNTKScalingRotaryEmbedding(mmMambaRotaryEmbedding):
+    """mmMambaRotaryEmbedding extended with Dynamic NTK scaling.
+    Credits to the Reddit users /u/bloc97 and /u/emozilla.
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        super().__init__(dim, max_position_embeddings, base, device)
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if seq_len > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.register_buffer('inv_freq', inv_freq, persistent=False)
+
+        t = torch.arange(self.max_seq_len_cached, device=device).to(dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer('cos_cached', emb.cos().to(dtype), persistent=False)
+        self.register_buffer('sin_cached', emb.sin().to(dtype), persistent=False)
+
+
+
 class mmMambaMLP(nn.Module):
     def __init__(self, config):
        super().__init__()
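
The second hunk adds the three Llama-style rotary embedding variants. For context, here is a minimal sketch of how their cos/sin caches are typically consumed; the rotate_half and apply_rotary_pos_emb helpers below follow the standard implementation referenced by the "Copied from" comments and are assumptions for illustration, not code shown in this diff.

import torch

def rotate_half(x):
    # Rotate half of the head dimensions: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin: [seq_len, head_dim] caches from the rotary module; gather the
    # positions in use and broadcast over batch and heads.
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, q_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Hypothetical usage with the class added in this commit
# (q, k shaped [bs, num_heads, seq_len, head_dim]):
rotary = mmMambaRotaryEmbedding(dim=128, max_position_embeddings=2048)
q = torch.randn(1, 8, 16, 128)
k = torch.randn(1, 8, 16, 128)
cos, sin = rotary(q, seq_len=16)
position_ids = torch.arange(16).unsqueeze(0)
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)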
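
A quick numeric illustration of the Dynamic NTK branch in mmMambaDynamicNTKScalingRotaryEmbedding: once seq_len exceeds max_position_embeddings, the rotary base is enlarged so the lowest frequencies still span the longer context. The numbers below are illustrative, not values from the model config.

dim, base, max_pos, scaling_factor = 128, 10000.0, 2048, 1.0
seq_len = 4096  # twice the assumed training context
new_base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
print(round(new_base, 1))  # ~20221.3: the base roughly doubles for a 2x longer sequence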
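
The import hunk swaps the fla.modules dependency for a local fused_norm_gate module providing FusedRMSNormSwishGate. That module's implementation is not shown in this diff, so the following is only an unfused reference sketch, under the assumption (based on the name and the fla layer it replaces) that it RMS-normalizes one tensor and gates it with SiLU/Swish of another; the real fused Triton kernel may differ in signature and numerics.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormSwishGateRef(nn.Module):
    """Unfused reference for the assumed semantics of FusedRMSNormSwishGate."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, gate):
        # RMS-normalize in float32 for stability, mirroring mmMambaRMSNorm above.
        dtype = x.dtype
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x = self.weight * x.to(dtype)
        # Swish/SiLU gating on the second input.
        return x * F.silu(gate)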