min
Browse files- run_mlm_flax_stream.py +1 -1
run_mlm_flax_stream.py
CHANGED
@@ -564,7 +564,7 @@ if __name__ == "__main__":
|
|
564 |
train_metrics = []
|
565 |
eval_metrics = []
|
566 |
|
567 |
-
training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=
|
568 |
|
569 |
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
570 |
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|
|
|
564 |
train_metrics = []
|
565 |
eval_metrics = []
|
566 |
|
567 |
+
training_iter = iter(torch.utils.data.DataLoader(tokenized_datasets.with_format("torch"), batch_size=1, shuffle=False, num_workers=min(33,dataset.n_shards), collate_fn=lambda x: x))
|
568 |
|
569 |
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
570 |
eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
|