qianyuchen commited on
Commit
85f72b2
verified
1 Parent(s): 5a68d09

Update resampler.py

Browse files
Files changed (1) hide show
  1. resampler.py +2 -2
resampler.py CHANGED
@@ -237,7 +237,7 @@ class Resampler(nn.Module):
237
 
238
 
239
  def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
240
- bs = k.shape[0]
241
 
242
  if pos_embed_temporal:
243
  # temporal 缁村害鎶樺彔
@@ -272,7 +272,7 @@ class Resampler(nn.Module):
272
 
273
 
274
  def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
275
- bs = k.shape[0]
276
 
277
  if pos_embed_temporal:
278
  k += torch.stack(pos_embed_temporal, dim=0)
 
237
 
238
 
239
  def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
240
+ bs = k.shape[1]
241
 
242
  if pos_embed_temporal:
243
  # temporal 缁村害鎶樺彔
 
272
 
273
 
274
  def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
275
+ bs = k.shape[1]
276
 
277
  if pos_embed_temporal:
278
  k += torch.stack(pos_embed_temporal, dim=0)