Update modularStarEncoder.py
Browse files- modularStarEncoder.py +3 -3
modularStarEncoder.py
CHANGED
|
@@ -205,15 +205,15 @@ def get_pooling_mask(
|
|
| 205 |
repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
|
| 206 |
|
| 207 |
DEVICE = input_ids.get_device()
|
| 208 |
-
print(DEVICE)
|
| 209 |
|
| 210 |
if DEVICE<0:
|
| 211 |
DEVICE = "cpu"
|
| 212 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
print(repeated_idx.get_device(),ranges.get_device())
|
| 215 |
pooling_mask = (repeated_idx <= ranges).long()
|
| 216 |
-
|
| 217 |
|
| 218 |
return pooling_mask
|
| 219 |
|
|
|
|
| 205 |
repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
|
| 206 |
|
| 207 |
DEVICE = input_ids.get_device()
|
|
|
|
| 208 |
|
| 209 |
if DEVICE<0:
|
| 210 |
DEVICE = "cpu"
|
| 211 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
| 212 |
+
ranges.to(DEVICE)
|
| 213 |
+
repeated_idx.to(DEVICE)
|
| 214 |
|
|
|
|
| 215 |
pooling_mask = (repeated_idx <= ranges).long()
|
| 216 |
+
|
| 217 |
|
| 218 |
return pooling_mask
|
| 219 |
|