Fix eval mask (#93)
Co-authored-by: Srini Iyer <[email protected]>
- bytelatent/args.py +3 -0
- bytelatent/eval.py +6 -2
bytelatent/args.py CHANGED
```diff
@@ -260,6 +260,9 @@ class ValidationArgs(BaseModel):
     max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
+    max_n_batches: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = []  # Other sources to eval on
```
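For illustration, a minimal sketch of how the new field might be set. The field names and defaults come from the diff above; the concrete values are invented, and this assumes the remaining `ValidationArgs` fields all have defaults:

```python
from bytelatent.args import ValidationArgs

# Illustrative values only; field names match the diff above.
val_args = ValidationArgs(
    root_dir="/path/to/validation/data",  # hypothetical path
    sources=["example_source"],           # hypothetical source name
    max_n_docs=None,      # None -> use every document in the file
    max_n_batches=100,    # new field: stop evaluation after 100 batches
)
```

Note the `/!\` comment carried over from `max_n_docs`: the effective amount of data evaluated under a fixed cap scales with the number of GPUs, since each rank consumes its own batches.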
bytelatent/eval.py CHANGED
```diff
@@ -153,6 +153,7 @@ def eval_ppl_on_path(
     path: str,
     arrow_batch_size: int,
     max_n_docs: int | None,
+    max_n_batches: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
@@ -189,7 +190,9 @@
     total_loss = 0.0
     n_bytes = 0
     batch_iterator = packing_iterator.create_iter()
-    for batch in batch_iterator:
+    for i, batch in enumerate(batch_iterator):
+        if i == max_n_batches:
+            break
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
@@ -203,7 +206,7 @@
                 pred = model(x, patch_lengths=patch_lengths)
             else:
                 pred = model(x)
-            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
             total_loss += loss.item()
         else:
             raise NotImplementedError()
```
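The `ignore_index=0` argument in the last hunk is the mask fix itself: target positions that should not be scored, which this diff relies on carrying id 0 in `y`, now contribute nothing to the summed loss. A standalone sketch of that behavior, with toy shapes and the pad-id-0 convention assumed from the diff:

```python
import torch
import torch.nn.functional as F

# Toy batch: 1 sequence, 4 positions, vocabulary of 5.
logits = torch.randn(1, 4, 5)
y = torch.tensor([[2, 3, 0, 0]])  # last two targets are masked/padding (id 0)

# Old behavior: padded positions are scored and inflate the summed loss.
loss_old = F.cross_entropy(logits.flatten(0, 1), y.flatten(0, 1), reduction="sum")

# New behavior: ignore_index=0 drops those positions from the sum entirely.
loss_new = F.cross_entropy(
    logits.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
)

assert loss_new <= loss_old  # the masked positions no longer count
```

The remaining hunk simply threads the new `max_n_batches` setting from the validation config through `launch_eval`: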
```diff
@@ -301,6 +304,7 @@ def launch_eval(eval_args: EvalArgs):
             add_patches=train_cfg.data.add_patches,
             path=os.path.join(eval_args.validation.root_dir, source),
             max_n_docs=eval_args.validation.max_n_docs,
+            max_n_batches=eval_args.validation.max_n_batches,
             arrow_batch_size=20,
             s3_profile=eval_args.s3_profile,
         )
```
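Stepping back, the loop change in `eval_ppl_on_path` is a plain early-exit cap. A standalone sketch of the same pattern; the function name here is hypothetical, while the `i == max_n_batches` comparison is copied from the diff (with `None` it never matches, preserving the old consume-everything behavior):

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")

def take_batches(batches: Iterable[T], max_n_batches: int | None = None) -> Iterator[T]:
    # Mirrors the loop in eval_ppl_on_path: enumerate plus early break.
    # With max_n_batches=None the comparison never matches, so the whole
    # iterator is consumed, matching the pre-patch behavior.
    for i, batch in enumerate(batches):
        if i == max_n_batches:
            break
        yield batch

# Usage, e.g.: for batch in take_batches(batch_iterator, max_n_batches=100): ...
```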