par-meta committed (unverified)
Commit 1da3dd9 · Parent: b0120da

Update preprocess_entropies script to blt inference + add fsspec support (#23)

bytelatent/data/patcher.py CHANGED
@@ -82,16 +82,16 @@ def calculate_entropies(
         if device is not None:
             split = split.to(device)
         assert torch.all(split >= 0) and torch.all(split < 260)
-        pred, _ = entropy_model(split)
+        pred = entropy_model(split)
         pred = pred.reshape(-1, pred.shape[-1])[
             : split.numel() - pad_size, :
         ]  # [batch_size * seq_len, vocab]
         pred_entropies = entropy(pred)
         entropies.append(pred_entropies)
 
-    entropies = torch.cat(entropies, dim=0)
-    entropies = entropies.reshape(tokens.shape)
-    return entropies
+    concat_entropies = torch.cat(entropies, dim=0)
+    concat_entropies = concat_entropies.reshape(tokens.shape)
+    return concat_entropies
 
 
 def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
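Note on the hunk above: the entropy model used for blt inference now returns the logits tensor directly, so the `pred, _ = ...` tuple unpacking goes away, and the concatenated result gets its own name (`concat_entropies`) instead of shadowing the `entropies` list. The `entropy()` helper applied to the `[batch_size * seq_len, vocab]` logits is defined elsewhere in patcher.py and is not part of this diff; a minimal sketch of the per-position computation it presumably performs (names here are illustrative, not the repo's):

    import torch
    import torch.nn.functional as F

    def entropy_sketch(logits: torch.Tensor) -> torch.Tensor:
        # Shannon entropy of the predicted next-byte distribution at each
        # position, computed from [N, vocab] logits; returns an [N] tensor.
        log_probs = F.log_softmax(logits, dim=-1)
        return -(log_probs.exp() * log_probs).sum(dim=-1)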
bytelatent/preprocess/preprocess_entropies.py CHANGED
@@ -1,14 +1,59 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 import time
-from pathlib import Path
 
+import fsspec
+import jsonlines
 import numpy as np
 import pyarrow as pa
 import torch
 import typer
 from rich.progress import Progress, TextColumn
 
-from bytelatent.data.iterators.arrow_iterator import ArrowFileIterator
+from bytelatent.data.file_util import get_fs
+from bytelatent.data.patcher import calculate_entropies
+from bytelatent.entropy_model import load_entropy_model
+from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
+
+
+def get_id_from_doc(doc: dict) -> int:
+    """
+    We need a reliable way to ensure that samples from jsonl
+    and arrow are the same, but there is no unique id field,
+    so derive the best possible
+    """
+    if "sample_id" in doc:
+        sample_id = doc["sample_id"]
+    elif "title" in doc:
+        sample_id = doc["title"]
+    elif "qid" in doc:
+        sample_id = doc["qid"]
+    elif "paper_id" in doc:
+        sample_id = doc["paper_id"]
+    elif "path" in doc:
+        sample_id = doc["path"]
+    elif "url" in doc:
+        sample_id = doc["url"]
+    elif "id" in doc:
+        sample_id = doc["id"]
+    else:
+        raise ValueError(f"Could not find a id key from: {doc.keys()}")
+    return str(sample_id)
+
+
+def get_text(doc: dict):
+    if "text" in doc:
+        text = doc["text"]
+    elif "content" in doc:
+        text = doc["content"]
+    else:
+        raise ValueError(f"Could not find a text key from: {doc.keys()}")
+    return text
+
+
+def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
+    with fs.open(path) as f:
+        reader = jsonlines.Reader(f)
+        yield from reader
 
 
 def main(
@@ -16,39 +61,32 @@ def main(
     output_file: str,
     patching_device: str = "cuda",
     log_step: int = 10_000,
-    entropy_model_checkpoint_dir: str = "entropy_checkpoint_dir",
+    entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
+    entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
+    bpe_tokenizer_path: str = "public_data/tokenizer.model",
     dry_run: bool = False,
+    s3_profile: str | None = None,
 ):
-    # TODO: Modify this to work with the new code
-    raise NotImplementedError()
-    iterator = ArrowFileIterator(
-        file_path=input_file,
-        worker_id=0,
-        num_workers=1,
-    )
-    tokenization_mode = "bytes"
     print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
     print("Loading entropy model", entropy_model_checkpoint_dir)
+    input_fs = get_fs(input_file, s3_profile=s3_profile)
+    input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
+
     if dry_run:
         return
     entropy_model = load_entropy_model(
-        entropy_model_checkpoint_dir, device=patching_device
+        entropy_model_checkpoint_dir,
+        entropy_model_state_dict_path,
+        device=patching_device,
     )
-    entropy_model, _ = to_device(entropy_model, patching_device)
+
     print("Creating patcher")
     patching_batch_size = 32
     print("Creating tokenizer")
-    tokenizer = Tokenizer(
-        model_path="/home/artidoro/tokenizers/llama_v2.tokenizer.model",
-        tokenization_mode=tokenization_mode,
-        # BYTE_UNITS
-        vocab_size_unit_1=256,
-        bos=True,
-        eos=True,
-        bpe_delim=False,
-        # This isn't used, just stores a reference for other calls we don't use
-        patcher=None,
+    tokenizer_args = TokenizerArgs(
+        name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
     )
+    tokenizer = tokenizer_args.build()
     step = 0
     print("starting")
     start_time = time.time()
@@ -59,8 +97,10 @@ def main(
     schema = pa.schema([sample_id_field, text_field, entropy_field])
     arrow_batch_size = 1_000
 
+    output_fs = get_fs(output_file, s3_profile=s3_profile)
+
     try:
-        with pa.OSFile(output_file, "wb") as sink:
+        with output_fs.open(output_file, "wb") as sink:
             with pa.ipc.new_file(sink, schema) as writer:
                 id_buffer = []
                 entropies_buffer = []
@@ -72,17 +112,9 @@ def main(
                     task = progress.add_task(
                         "[green]Calculating entropies...", total=None
                     )
-                    for doc in iterator:
+                    for doc in input_doc_iterator:
                         sample_id = get_id_from_doc(doc)
-
-                        if "text" in doc:
-                            text = doc["text"]
-                        elif "content" in doc:
-                            text = doc["content"]
-                        else:
-                            raise ValueError(
-                                f"Could not find a text key from: {doc.keys()}"
-                            )
+                        text = get_text(doc)
                         tokens = torch.tensor(tokenizer.encode(text))
                         patch_start = time.time()
                         scores = calculate_entropies(
@@ -128,9 +160,10 @@ def main(
                             entropies_buffer = []
                             id_buffer = []
                             text_buffer = []
-        Path(f"{output_file}.complete").touch()
+        output_fs.touch(f"{output_file}.complete")
     except:
-        Path(output_file).unlink(missing_ok=True)
+        if output_fs.exists(output_file):
+            output_fs.rm(output_file)
         raise
     elapsed = time.time() - start_time
     print("steps", step)
requirements.txt CHANGED
@@ -21,3 +21,4 @@ submitit
 typer
 rich
 fsspec[full]
+orjson