par-meta commited on
Commit
f84ee63
·
unverified ·
1 Parent(s): c110f6b

Update iterate_data (#81)

Browse files

Summary:

Test Plan:

Files changed (1) hide show
  1. bytelatent/iterate_data.py +15 -6
bytelatent/iterate_data.py CHANGED
@@ -8,10 +8,16 @@ from bytelatent.data.iterators.multiprocess_iterator import MultiprocessIterator
8
  from bytelatent.logger import init_logger
9
 
10
 
11
- def main(state_file: str):
 
 
 
 
 
 
12
  init_logger()
13
- pyarrow.set_io_thread_count(4)
14
- pyarrow.set_cpu_count(4)
15
  with open(state_file) as f:
16
  train_state = json.load(f)
17
  dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
@@ -20,10 +26,13 @@ def main(state_file: str):
20
  packing_iterator = packing_iterator_state.build()
21
  print("iter")
22
  batch_iter = packing_iterator.create_iter()
23
- batch = None
24
  print("looping")
25
- for i in track(range(1_000)):
26
- batch = next(batch_iter)
 
 
 
 
27
 
28
 
29
  if __name__ == "__main__":
 
8
  from bytelatent.logger import init_logger
9
 
10
 
11
+ def main(
12
+ state_file: str,
13
+ steps: int = 3_000,
14
+ io_thread_count: int = 2,
15
+ cpu_count: int = 2,
16
+ log_freq: int = 100,
17
+ ):
18
  init_logger()
19
+ pyarrow.set_io_thread_count(io_thread_count)
20
+ pyarrow.set_cpu_count(cpu_count)
21
  with open(state_file) as f:
22
  train_state = json.load(f)
23
  dl_state = MultiprocessIteratorState(**train_state["data_loader_state"])
 
26
  packing_iterator = packing_iterator_state.build()
27
  print("iter")
28
  batch_iter = packing_iterator.create_iter()
 
29
  print("looping")
30
+ for i in track(range(steps)):
31
+ _ = next(batch_iter)
32
+ if i % log_freq == 0:
33
+ print(pyarrow.default_memory_pool())
34
+ print(i)
35
+ print(pyarrow.default_memory_pool())
36
 
37
 
38
  if __name__ == "__main__":