Srinivasan Iyer (sviyer) committed
Commit 19a3f75 · unverified · 1 Parent(s): 8c1b1a7

Fix eval mask (#93)


Co-authored-by: Srini Iyer <[email protected]>

Files changed (2)
  1. bytelatent/args.py +3 -0
  2. bytelatent/eval.py +6 -2
bytelatent/args.py CHANGED
@@ -260,6 +260,9 @@ class ValidationArgs(BaseModel):
     max_n_docs: int | None = (
         None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
     )
+    max_n_batches: int | None = (
+        None  # If None the whole validation file is used -> /!\ This number of steps is gpu dependent (100 max steps on 8 gpus = 800 steps on 1 gpu)
+    )
     use_val_from_train_src: bool = True  # Use the validation set from training sources
     root_dir: str = ""
     sources: list[str] = []  # Other sources to eval on
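
For reference, a minimal sketch of where the new field sits in the config model. It assumes only that ValidationArgs subclasses pydantic.BaseModel, as the hunk header shows; fields not visible in the diff are omitted, and the usage at the bottom is purely illustrative.

# Minimal sketch of the ValidationArgs fields touched by this commit.
from pydantic import BaseModel


class ValidationArgs(BaseModel):
    max_n_docs: int | None = None     # None -> read the whole validation file
    max_n_batches: int | None = None  # None -> no cap on evaluation batches
    use_val_from_train_src: bool = True
    root_dir: str = ""
    sources: list[str] = []


# Illustrative use: cap each validation source at 100 batches.
args = ValidationArgs(max_n_batches=100)
assert args.max_n_docs is None  # docs stay uncapped unless set explicitly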
bytelatent/eval.py CHANGED
@@ -153,6 +153,7 @@ def eval_ppl_on_path(
     path: str,
     arrow_batch_size: int,
     max_n_docs: int | None,
+    max_n_batches: int | None,
     s3_profile: str | None = None,
 ):
     model.eval()
@@ -189,7 +190,9 @@ def eval_ppl_on_path(
     total_loss = 0.0
     n_bytes = 0
     batch_iterator = packing_iterator.create_iter()
-    for batch in batch_iterator:
+    for i, batch in enumerate(batch_iterator):
+        if i == max_n_batches:
+            break
         x = torch.from_numpy(batch.x).cuda()
         y = torch.from_numpy(batch.y).cuda()
         mask = None if batch.mask is None else torch.from_numpy(batch.mask).cuda()
@@ -203,7 +206,7 @@ def eval_ppl_on_path(
                 pred = model(x, patch_lengths=patch_lengths)
             else:
                 pred = model(x)
-            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
+            loss = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0)
             total_loss += loss.item()
         else:
             raise NotImplementedError()
@@ -301,6 +304,7 @@ def launch_eval(eval_args: EvalArgs):
             add_patches=train_cfg.data.add_patches,
             path=os.path.join(eval_args.validation.root_dir, source),
             max_n_docs=eval_args.validation.max_n_docs,
+            max_n_batches=eval_args.validation.max_n_batches,
             arrow_batch_size=20,
             s3_profile=eval_args.s3_profile,
         )
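
Why ignore_index=0 fixes the eval mask: with reduction="sum", every position in y contributes loss mass to total_loss, including padding. The standalone sketch below assumes, as the diff implies, that padded positions in y carry token id 0; the shapes and values are illustrative, not taken from bytelatent.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
pred = torch.randn(1, 5, 8)           # [batch, seq, vocab] logits
y = torch.tensor([[3, 1, 4, 0, 0]])   # last two positions are padding (id 0)

# Before the fix: padding positions add loss mass to the sum.
loss_all = F.cross_entropy(pred.flatten(0, 1), y.flatten(0, 1), reduction="sum")
# After the fix: targets equal to 0 are skipped entirely.
loss_masked = F.cross_entropy(
    pred.flatten(0, 1), y.flatten(0, 1), reduction="sum", ignore_index=0
)
assert loss_masked.item() < loss_all.item()  # padding no longer inflates the sum

One caveat, stated as an assumption: this treats byte id 0 as reserved for padding, so if 0 were ever a legitimate target those positions would be dropped too. The batch cap added in the same commit needs no special-casing for None: since `i == max_n_batches` is simply False for any integer `i` when max_n_batches is None, the loop then runs the iterator to exhaustion.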