par-meta committed on
Commit
85c2f28
·
unverified ·
1 Parent(s): 9d907fe

Test first batch matches (#53)

Browse files
bytelatent/data/iterators/test_arrow_iterator.py CHANGED
@@ -28,6 +28,7 @@ def test_basic_arrow_file():
28
  row_num=0,
29
  arrow_batch_size=100,
30
  s3_profile=None,
 
31
  )
32
  arrow_file = initial_state.build()
33
  start_state = arrow_file.get_state()
@@ -57,6 +58,7 @@ def test_basic_arrow_file():
57
  row_num=251,
58
  arrow_batch_size=100,
59
  s3_profile=None,
 
60
  )
61
  arrow_file = resumed_state.build()
62
  for example in arrow_file.create_iter():
@@ -77,6 +79,7 @@ def test_basic_arrow_file():
77
  row_num=0,
78
  arrow_batch_size=100,
79
  s3_profile=None,
 
80
  )
81
  arrow_file = rank_state.build()
82
  expected_ids = []
 
28
  row_num=0,
29
  arrow_batch_size=100,
30
  s3_profile=None,
31
+ file_format="arrow",
32
  )
33
  arrow_file = initial_state.build()
34
  start_state = arrow_file.get_state()
 
58
  row_num=251,
59
  arrow_batch_size=100,
60
  s3_profile=None,
61
+ file_format="arrow",
62
  )
63
  arrow_file = resumed_state.build()
64
  for example in arrow_file.create_iter():
 
79
  row_num=0,
80
  arrow_batch_size=100,
81
  s3_profile=None,
82
+ file_format="arrow",
83
  )
84
  arrow_file = rank_state.build()
85
  expected_ids = []
bytelatent/data/test_data.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+
4
+ import pytest
5
+ from omegaconf import OmegaConf
6
+
7
+ from bytelatent.args import TrainArgs
8
+ from bytelatent.constants import BLT_DATA
9
+
10
+
11
def get_test_config():
    """Return the path of the internal ``tests.yaml`` config file.

    The containing directory is read from the ``BLT_INTERNAL`` environment
    variable when it is set; otherwise the relative directory
    ``../internal-blt/configs`` is used. The path is only constructed here,
    it is not checked for existence.
    """
    config_dir = os.environ.get("BLT_INTERNAL", "../internal-blt/configs")
    return os.path.join(config_dir, "tests.yaml")
18
+
19
+
20
@pytest.mark.skipif(
    not os.path.exists(get_test_config()),
    reason="Skipping since internal config is missing",
)
def test_first_batch_matches():
    """Check that the data loader reproduces a previously pickled first batch.

    The training arguments are built by merging the defaults from
    ``TrainArgs()`` with the internal ``tests.yaml`` config, then the first
    batch produced by the data loader is compared field by field against the
    fixture stored at ``BLT_DATA/fixtures/first_batch_0.pickle``.

    The test is skipped when the internal config file does not exist.
    """
    test_config_path = get_test_config()
    default_cfg = OmegaConf.create(TrainArgs().model_dump())
    file_cfg = OmegaConf.load(test_config_path)
    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
    merged_cfg = OmegaConf.to_container(merged_cfg, resolve=True, throw_on_missing=True)
    train_args = TrainArgs.model_validate(merged_cfg)
    # MP doesn't work with async very well, but it doesn't change logic
    train_args.data.load_async = False

    # Test data created by pickling first batch in train loop then exiting.
    # NOTE(review): pickle.load can execute arbitrary code; this is only
    # acceptable because the fixture is a trusted, locally generated file.
    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
        first_batch = pickle.load(f)

    # Emulate 1 node, 8 gpu training
    data_loader = train_args.data.build_from_rank(0, 8)
    batch_iterator = data_loader.create_iter()
    print("Getting first batch")
    batch = next(batch_iterator)
    assert (batch.x == first_batch.x).all()
    assert (batch.y == first_batch.y).all()
    assert (batch.mask == first_batch.mask).all()
    assert (batch.patch_lengths == first_batch.patch_lengths).all()
    # Separate asserts so a failure identifies which side is wrong.
    assert batch.ngram_ids is None
    assert first_batch.ngram_ids is None
    # The fixture must be checked too; comparing `batch` twice left
    # `first_batch.is_final` unverified.
    assert batch.is_final == False
    assert first_batch.is_final == False
bytelatent/test_entropy_model.py CHANGED
@@ -25,6 +25,7 @@ def test_entropy_model():
25
  row_num=0,
26
  arrow_batch_size=100,
27
  s3_profile=None,
 
28
  )
29
  arrow_file = initial_state.build()
30
  tokenizer_args = TokenizerArgs(
 
25
  row_num=0,
26
  arrow_batch_size=100,
27
  s3_profile=None,
28
+ file_format="arrow",
29
  )
30
  arrow_file = initial_state.build()
31
  tokenizer_args = TokenizerArgs(
pyproject.toml CHANGED
@@ -2,4 +2,5 @@
2
  profile = "black"
3
  known_bytelatent = "bytelatent"
4
  known_apps = "apps"
 
5
  sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"
 
2
  profile = "black"
3
  known_bytelatent = "bytelatent"
4
  known_apps = "apps"
5
+ known_third_party = "wandb"
6
  sections = "FUTURE,STDLIB,THIRDPARTY,BYTELATENT,APPS,FIRSTPARTY,LOCALFOLDER"