Spaces:
Runtime error
Runtime error
Update infer.py
Browse files
infer.py
CHANGED
|
@@ -147,25 +147,23 @@ def get_model_and_tokenizer(model_name: str, optimization_level: str, progress):
|
|
| 147 |
)
|
| 148 |
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
# return_tensors="pt"
|
| 164 |
-
# )
|
| 165 |
|
| 166 |
-
|
| 167 |
|
| 168 |
-
|
| 169 |
|
| 170 |
|
| 171 |
@torch.inference_mode()
|
|
@@ -247,8 +245,8 @@ def batch_embed(
|
|
| 247 |
|
| 248 |
start_time = time.time()
|
| 249 |
|
| 250 |
-
collator =
|
| 251 |
-
|
| 252 |
)
|
| 253 |
|
| 254 |
dl = DataLoader(
|
|
|
|
| 147 |
)
|
| 148 |
|
| 149 |
|
| 150 |
+
def collate_fn(examples, column_name, tokenizer):
    """Collate pre-tokenized examples into a single padded batch.

    Args:
        examples: Sequence of mappings. Each one must provide the keys
            ``"input_ids"`` and ``"attention_mask"`` as well as
            ``column_name``.
        column_name: Key whose per-example values are carried through to the
            batch unchanged, as a plain Python list.
        tokenizer: Object exposing a ``pad(features, **kwargs)`` method; its
            return value must support item assignment.

    Returns:
        Whatever ``tokenizer.pad`` returns, with one extra entry stored under
        ``column_name`` holding the list of per-example values.

    Raises:
        KeyError: If an example lacks one of the required keys.
    """
    feature_cols = ["input_ids", "attention_mask"]
    # Hand only the model inputs to the tokenizer; any other columns present
    # on the examples are left out of the padding call.
    features = [{k: x[k] for k in feature_cols} for x in examples]

    # NOTE(review): with padding=True the tokenizer presumably pads to the
    # longest sequence in the batch, so max_length may have no effect here.
    # Confirm against the tokenizer documentation before relying on it.
    tokenized = tokenizer.pad(
        features,
        padding=True,
        max_length=512,
        return_tensors="pt",
        pad_to_multiple_of=16,
    )

    # Re-attach the pass-through column next to the padded features.
    tokenized[column_name] = [x[column_name] for x in examples]

    return tokenized
|
| 167 |
|
| 168 |
|
| 169 |
@torch.inference_mode()
|
|
|
|
| 245 |
|
| 246 |
start_time = time.time()
|
| 247 |
|
| 248 |
+
collator = partial(
|
| 249 |
+
collate_fn, column_name=column_name, tokenizer=tokenizer
|
| 250 |
)
|
| 251 |
|
| 252 |
dl = DataLoader(
|