Spaces:
Running
on
Zero
Running
on
Zero
Update preprocess_entropies script to blt inference + add fsspec support (#23)
Browse files- bytelatent/data/patcher.py +4 -4
- bytelatent/preprocess/preprocess_entropies.py +69 -36
- requirements.txt +1 -0
bytelatent/data/patcher.py
CHANGED
@@ -82,16 +82,16 @@ def calculate_entropies(
|
|
82 |
if device is not None:
|
83 |
split = split.to(device)
|
84 |
assert torch.all(split >= 0) and torch.all(split < 260)
|
85 |
-
pred
|
86 |
pred = pred.reshape(-1, pred.shape[-1])[
|
87 |
: split.numel() - pad_size, :
|
88 |
] # [batch_size * seq_len, vocab]
|
89 |
pred_entropies = entropy(pred)
|
90 |
entropies.append(pred_entropies)
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
return
|
95 |
|
96 |
|
97 |
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
|
|
|
82 |
if device is not None:
|
83 |
split = split.to(device)
|
84 |
assert torch.all(split >= 0) and torch.all(split < 260)
|
85 |
+
pred = entropy_model(split)
|
86 |
pred = pred.reshape(-1, pred.shape[-1])[
|
87 |
: split.numel() - pad_size, :
|
88 |
] # [batch_size * seq_len, vocab]
|
89 |
pred_entropies = entropy(pred)
|
90 |
entropies.append(pred_entropies)
|
91 |
|
92 |
+
concat_entropies = torch.cat(entropies, dim=0)
|
93 |
+
concat_entropies = concat_entropies.reshape(tokens.shape)
|
94 |
+
return concat_entropies
|
95 |
|
96 |
|
97 |
def patch_start_mask_from_entropy_with_monotonicity(entropies, t):
|
bytelatent/preprocess/preprocess_entropies.py
CHANGED
@@ -1,14 +1,59 @@
|
|
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
import time
|
3 |
-
from pathlib import Path
|
4 |
|
|
|
|
|
5 |
import numpy as np
|
6 |
import pyarrow as pa
|
7 |
import torch
|
8 |
import typer
|
9 |
from rich.progress import Progress, TextColumn
|
10 |
|
11 |
-
from bytelatent.data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
|
14 |
def main(
|
@@ -16,39 +61,32 @@ def main(
|
|
16 |
output_file: str,
|
17 |
patching_device: str = "cuda",
|
18 |
log_step: int = 10_000,
|
19 |
-
entropy_model_checkpoint_dir: str = "
|
|
|
|
|
20 |
dry_run: bool = False,
|
|
|
21 |
):
|
22 |
-
# TODO: Modify this to work with the new code
|
23 |
-
raise NotImplementedError()
|
24 |
-
iterator = ArrowFileIterator(
|
25 |
-
file_path=input_file,
|
26 |
-
worker_id=0,
|
27 |
-
num_workers=1,
|
28 |
-
)
|
29 |
-
tokenization_mode = "bytes"
|
30 |
print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
|
31 |
print("Loading entropy model", entropy_model_checkpoint_dir)
|
|
|
|
|
|
|
32 |
if dry_run:
|
33 |
return
|
34 |
entropy_model = load_entropy_model(
|
35 |
-
entropy_model_checkpoint_dir,
|
|
|
|
|
36 |
)
|
37 |
-
|
38 |
print("Creating patcher")
|
39 |
patching_batch_size = 32
|
40 |
print("Creating tokenizer")
|
41 |
-
|
42 |
-
|
43 |
-
tokenization_mode=tokenization_mode,
|
44 |
-
# BYTE_UNITS
|
45 |
-
vocab_size_unit_1=256,
|
46 |
-
bos=True,
|
47 |
-
eos=True,
|
48 |
-
bpe_delim=False,
|
49 |
-
# This isn't used, just stores a reference for other calls we don't use
|
50 |
-
patcher=None,
|
51 |
)
|
|
|
52 |
step = 0
|
53 |
print("starting")
|
54 |
start_time = time.time()
|
@@ -59,8 +97,10 @@ def main(
|
|
59 |
schema = pa.schema([sample_id_field, text_field, entropy_field])
|
60 |
arrow_batch_size = 1_000
|
61 |
|
|
|
|
|
62 |
try:
|
63 |
-
with
|
64 |
with pa.ipc.new_file(sink, schema) as writer:
|
65 |
id_buffer = []
|
66 |
entropies_buffer = []
|
@@ -72,17 +112,9 @@ def main(
|
|
72 |
task = progress.add_task(
|
73 |
"[green]Calculating entropies...", total=None
|
74 |
)
|
75 |
-
for doc in
|
76 |
sample_id = get_id_from_doc(doc)
|
77 |
-
|
78 |
-
if "text" in doc:
|
79 |
-
text = doc["text"]
|
80 |
-
elif "content" in doc:
|
81 |
-
text = doc["content"]
|
82 |
-
else:
|
83 |
-
raise ValueError(
|
84 |
-
f"Could not find a text key from: {doc.keys()}"
|
85 |
-
)
|
86 |
tokens = torch.tensor(tokenizer.encode(text))
|
87 |
patch_start = time.time()
|
88 |
scores = calculate_entropies(
|
@@ -128,9 +160,10 @@ def main(
|
|
128 |
entropies_buffer = []
|
129 |
id_buffer = []
|
130 |
text_buffer = []
|
131 |
-
|
132 |
except:
|
133 |
-
|
|
|
134 |
raise
|
135 |
elapsed = time.time() - start_time
|
136 |
print("steps", step)
|
|
|
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
import time
|
|
|
3 |
|
4 |
+
import fsspec
|
5 |
+
import jsonlines
|
6 |
import numpy as np
|
7 |
import pyarrow as pa
|
8 |
import torch
|
9 |
import typer
|
10 |
from rich.progress import Progress, TextColumn
|
11 |
|
12 |
+
from bytelatent.data.file_util import get_fs
|
13 |
+
from bytelatent.data.patcher import calculate_entropies
|
14 |
+
from bytelatent.entropy_model import load_entropy_model
|
15 |
+
from bytelatent.tokenizers.build_tokenizer import TokenizerArgs
|
16 |
+
|
17 |
+
|
18 |
+
def get_id_from_doc(doc: dict) -> int:
|
19 |
+
"""
|
20 |
+
We need a reliable way to ensure that samples from jsonl
|
21 |
+
and arrow are the same, but there is no unique id field,
|
22 |
+
so derive the best possible
|
23 |
+
"""
|
24 |
+
if "sample_id" in doc:
|
25 |
+
sample_id = doc["sample_id"]
|
26 |
+
elif "title" in doc:
|
27 |
+
sample_id = doc["title"]
|
28 |
+
elif "qid" in doc:
|
29 |
+
sample_id = doc["qid"]
|
30 |
+
elif "paper_id" in doc:
|
31 |
+
sample_id = doc["paper_id"]
|
32 |
+
elif "path" in doc:
|
33 |
+
sample_id = doc["path"]
|
34 |
+
elif "url" in doc:
|
35 |
+
sample_id = doc["url"]
|
36 |
+
elif "id" in doc:
|
37 |
+
sample_id = doc["id"]
|
38 |
+
else:
|
39 |
+
raise ValueError(f"Could not find a id key from: {doc.keys()}")
|
40 |
+
return str(sample_id)
|
41 |
+
|
42 |
+
|
43 |
+
def get_text(doc: dict):
|
44 |
+
if "text" in doc:
|
45 |
+
text = doc["text"]
|
46 |
+
elif "content" in doc:
|
47 |
+
text = doc["content"]
|
48 |
+
else:
|
49 |
+
raise ValueError(f"Could not find a text key from: {doc.keys()}")
|
50 |
+
return text
|
51 |
+
|
52 |
+
|
53 |
+
def jsonl_file_iterator(fs: fsspec.AbstractFileSystem, path: str):
|
54 |
+
with fs.open(path) as f:
|
55 |
+
reader = jsonlines.Reader(f)
|
56 |
+
yield from reader
|
57 |
|
58 |
|
59 |
def main(
|
|
|
61 |
output_file: str,
|
62 |
patching_device: str = "cuda",
|
63 |
log_step: int = 10_000,
|
64 |
+
entropy_model_checkpoint_dir: str = "public_data/entropy_checkpoint",
|
65 |
+
entropy_model_state_dict_path: str = "public_data/entropy_model.pth",
|
66 |
+
bpe_tokenizer_path: str = "public_data/tokenizer.model",
|
67 |
dry_run: bool = False,
|
68 |
+
s3_profile: str | None = None,
|
69 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
print(f"Preprocessing entropies, input: {input_file}, output: {output_file}")
|
71 |
print("Loading entropy model", entropy_model_checkpoint_dir)
|
72 |
+
input_fs = get_fs(input_file, s3_profile=s3_profile)
|
73 |
+
input_doc_iterator = jsonl_file_iterator(input_fs, input_file)
|
74 |
+
|
75 |
if dry_run:
|
76 |
return
|
77 |
entropy_model = load_entropy_model(
|
78 |
+
entropy_model_checkpoint_dir,
|
79 |
+
entropy_model_state_dict_path,
|
80 |
+
device=patching_device,
|
81 |
)
|
82 |
+
|
83 |
print("Creating patcher")
|
84 |
patching_batch_size = 32
|
85 |
print("Creating tokenizer")
|
86 |
+
tokenizer_args = TokenizerArgs(
|
87 |
+
name="blt", init_kwargs={"bpe_tokenizer_path": bpe_tokenizer_path}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
)
|
89 |
+
tokenizer = tokenizer_args.build()
|
90 |
step = 0
|
91 |
print("starting")
|
92 |
start_time = time.time()
|
|
|
97 |
schema = pa.schema([sample_id_field, text_field, entropy_field])
|
98 |
arrow_batch_size = 1_000
|
99 |
|
100 |
+
output_fs = get_fs(output_file, s3_profile=s3_profile)
|
101 |
+
|
102 |
try:
|
103 |
+
with output_fs.open(output_file, "wb") as sink:
|
104 |
with pa.ipc.new_file(sink, schema) as writer:
|
105 |
id_buffer = []
|
106 |
entropies_buffer = []
|
|
|
112 |
task = progress.add_task(
|
113 |
"[green]Calculating entropies...", total=None
|
114 |
)
|
115 |
+
for doc in input_doc_iterator:
|
116 |
sample_id = get_id_from_doc(doc)
|
117 |
+
text = get_text(doc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
tokens = torch.tensor(tokenizer.encode(text))
|
119 |
patch_start = time.time()
|
120 |
scores = calculate_entropies(
|
|
|
160 |
entropies_buffer = []
|
161 |
id_buffer = []
|
162 |
text_buffer = []
|
163 |
+
output_fs.touch(f"{output_file}.complete")
|
164 |
except:
|
165 |
+
if output_fs.exists(output_file):
|
166 |
+
output_fs.rm(output_file)
|
167 |
raise
|
168 |
elapsed = time.time() - start_time
|
169 |
print("steps", step)
|
requirements.txt
CHANGED
@@ -21,3 +21,4 @@ submitit
|
|
21 |
typer
|
22 |
rich
|
23 |
fsspec[full]
|
|
|
|
21 |
typer
|
22 |
rich
|
23 |
fsspec[full]
|
24 |
+
orjson
|