Update resampler.py
Browse files- resampler.py +2 -2
resampler.py
CHANGED
@@ -237,7 +237,7 @@ class Resampler(nn.Module):
|
|
237 |
|
238 |
|
239 |
def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
240 |
-
bs = k.shape[
|
241 |
|
242 |
if pos_embed_temporal:
|
243 |
# temporal 缁村害鎶樺彔
|
@@ -272,7 +272,7 @@ class Resampler(nn.Module):
|
|
272 |
|
273 |
|
274 |
def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
275 |
-
bs = k.shape[
|
276 |
|
277 |
if pos_embed_temporal:
|
278 |
k += torch.stack(pos_embed_temporal, dim=0)
|
|
|
237 |
|
238 |
|
239 |
def batch_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
240 |
+
bs = k.shape[1]
|
241 |
|
242 |
if pos_embed_temporal:
|
243 |
# temporal 缁村害鎶樺彔
|
|
|
272 |
|
273 |
|
274 |
def foreach_attn_forward(self, q, k, v, pos_embed_temporal, temporal_ids, key_padding_mask):
|
275 |
+
bs = k.shape[1]
|
276 |
|
277 |
if pos_embed_temporal:
|
278 |
k += torch.stack(pos_embed_temporal, dim=0)
|