import os
import shutil
import signal
import sys
import time
from contextlib import ExitStack
from functools import partial
from pathlib import Path

from accelerate.utils import gather_object
from torchinfo import summary

from unidisc.tokenizers.chameleon_tokenizers import tokenize_chameleon

sys.path.append(str(Path(__file__).parent.parent.parent / "unidisc/misc/hydra_submitit_launcher"))
import itertools
import json
import random
from contextlib import nullcontext

import datasets
import fsspec
import hydra
import numpy as np
import omegaconf
import rich.syntax
import rich.tree
import torch
from accelerate import Accelerator
from PIL import Image
from tensordict import TensorDict
from tqdm import tqdm
from viztracer import VizTracer

from dataloader import get_dataloaders, get_tokenizer, tokenize_text
from decoupled_utils import (barrier, breakpoint_on_error, get_world_size,
                            gprint, is_local_main_process, is_main_process,
                            rank_zero_fn, rprint, set_global_breakpoint,
                            set_global_exists)
from model import decode_latents, get_image_batch, get_vae
from models.datasets.combine_token_dicts import main as combine_token_dicts
from models.datasets.vggface_v2_attributes import (get_inference_func,
                                                   get_output)
from utils import (_print_config, set_numa_affinity, set_omega_conf_resolvers,
                   set_torch_defaults)

os.environ["HYDRA_FULL_ERROR"] = "1"

set_global_breakpoint()
set_global_exists()
set_omega_conf_resolvers()
set_torch_defaults()


def get_dict(config, dataset_size):
    """Preallocate an empty TensorDict sized for `dataset_size` image-token rows."""
    data = TensorDict(
        {
            # int16 suffices for image codebooks with fewer than 32768 entries.
            "input_ids": torch.zeros(dataset_size, config.model.img_length, dtype=torch.int16),
            "idx": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int32),
            "write_flag": torch.zeros(dataset_size, 1, dtype=torch.bool),
            "modality": torch.full((dataset_size, 1), fill_value=-1, dtype=torch.int16),
        },
        batch_size=[dataset_size],
    )
    return data
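
# Illustrative usage of get_dict (the helper is not called in this file;
# `config.model.img_length` and the variable names below are assumptions):
#
#   data = get_dict(config, dataset_size=1024)
#   data["input_ids"][i] = image_tokens   # write one row of tokenized image data
#   data["write_flag"][i] = True          # mark the row as populated
#   data.memmap_(prefix=str(tensordict_output_dir))  # optionally persist to disk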


def _group_texts(examples, block_size, bos, eos):
    """Concatenate tokenized documents and re-chunk them into fixed-size blocks.

    Any trailing remainder shorter than `block_size` is dropped, and the first
    and last position of every block are overwritten with BOS/EOS.
    """
    concatenated_examples = list(itertools.chain(*examples['input_ids']))
    total_length = len(concatenated_examples)

    # Drop the partial block at the end so every example is exactly block_size long.
    total_length = (total_length // block_size) * block_size

    result = {}
    _values = []
    for i in range(0, total_length, block_size):
        _data = concatenated_examples[i : i + block_size]
        _data[0] = bos
        _data[-1] = eos
        _values.append(_data)

    result['input_ids'] = _values
    return result
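
# Worked example (illustrative token ids; bos=1 and eos=2 are assumptions,
# not the real special-token ids): with block_size=4, the two documents
# [10, 11, 12] and [13, 14, 15, 16, 17] concatenate to eight tokens and
# yield two BOS/EOS-stamped blocks:
#
#   _group_texts({'input_ids': [[10, 11, 12], [13, 14, 15, 16, 17]]},
#                block_size=4, bos=1, eos=2)
#   # -> {'input_ids': [[1, 11, 12, 2], [1, 15, 16, 2]]}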


def preprocess_and_tokenize(example, tokenizer, dataset_name, wrap, block_size, EOS, BOS):
    # Datasets store their text under different column names.
    if dataset_name == 'ptb':
        text = example['sentence']
    elif 'scientific_papers' in dataset_name:
        text = example['article']
    else:
        text = example['text']

    tokenizer.padding_side = 'right'
    tokenizer.truncation_side = 'right'

    if wrap:
        # No padding or truncation here: _group_texts re-chunks the token
        # stream later, which is also where BOS/EOS are stamped (the EOS/BOS
        # arguments are unused here and kept for the partial(...) call in main).
        tokens = tokenizer(text,
                           add_special_tokens=True,
                           return_attention_mask=False,
                           return_token_type_ids=False)
        tokens = {'input_ids': tokens['input_ids']}
    else:
        tokens = tokenizer(text,
                           max_length=block_size,
                           padding='max_length',
                           truncation=True,
                           add_special_tokens=True,
                           return_attention_mask=True,
                           return_token_type_ids=True)
    return tokens
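
# Sketch of the two output shapes (illustrative; a standard Hugging Face fast
# tokenizer is assumed): with wrap=True a batch {'text': ['hello world']}
# returns only {'input_ids': [[...]]} with variable-length rows; with
# wrap=False each row is padded/truncated to block_size and the output also
# carries 'attention_mask' and 'token_type_ids'.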


def add_modality(output_dataset):
    # add_column returns a new dataset (datasets are immutable), so the result
    # must be returned; the column must be a plain list rather than a torch
    # tensor. Each row's [0] marks it as the text modality, matching the
    # (dataset_size, 1) modality layout used in get_dict.
    modality_column = [[0]] * len(output_dataset)
    output_dataset = output_dataset.add_column("modality", modality_column)
    return output_dataset
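
# Hypothetical call site (add_modality is not invoked in this file); because
# datasets are immutable, the returned dataset must be captured:
#
#   output_dataset = add_modality(output_dataset)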


@hydra.main(version_base=None, config_path="../../configs", config_name="config")
def main(config):
    """Tokenize a text dataset, chunk it into fixed-size blocks, and save it to disk."""
    _print_config(config, resolve=True, save_cfg=True)
    tokenizer = get_tokenizer(config)
    block_size = config.data.block_size

    wrap = True
    streaming = config.data.streaming
    num_proc = config.data.num_proc
    split = getattr(config.data, "split", "train")
    use_cache = False

    assert getattr(config.data, "use_slow_tokenizer", False) is False

    output_dir = Path(config.data.token_output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tensordict_output_dir = output_dir.parent / f"{output_dir.stem}_tensordict"
    tensordict_output_dir.mkdir(parents=True, exist_ok=True)

    dataset_name = config.data.train
    if isinstance(dataset_name, list):
        data = datasets.concatenate_datasets([
            datasets.load_dataset(name, split=split, cache_dir=config.data.cache_dir, streaming=streaming)
            for name in dataset_name
        ])
    else:
        _args = []
        if getattr(config.data, "add_load_dataset_args", None) is not None:
            _args.append(config.data.add_load_dataset_args)
        data = datasets.load_dataset(dataset_name, *_args, split=split, cache_dir=config.data.cache_dir, streaming=streaming)

    EOS = tokenizer.eos_token_id
    BOS = tokenizer.bos_token_id

    if config.data.n_train_samples is not None:
        # Note: .select only exists on non-streaming datasets; use .take when streaming.
        print(f"Selecting {config.data.n_train_samples} samples")
        data = data.select(range(config.data.n_train_samples))

    _preprocess_and_tokenize = partial(preprocess_and_tokenize, tokenizer=tokenizer, dataset_name=dataset_name, wrap=wrap, block_size=block_size, EOS=EOS, BOS=BOS)

    if streaming:
        tokenized_dataset = data.map(
            _preprocess_and_tokenize,
            batched=True
        )
    else:
        rprint(f"Tokenizing with num_proc: {num_proc}")
        tokenized_dataset = data.map(
            _preprocess_and_tokenize,
            batched=True,
            num_proc=num_proc,
            load_from_cache_file=use_cache,
            desc='Tokenizing')

    # Drop every source column (e.g. 'text', 'sentence', or 'article') so only
    # 'input_ids' survives; a hardcoded remove_columns('text') would break on
    # datasets such as ptb. Streaming datasets may report column_names as None,
    # in which case there is nothing to prune eagerly.
    columns_to_keep = ['input_ids']
    if tokenized_dataset.column_names is not None:
        columns_to_remove = [col for col in tokenized_dataset.column_names if col not in columns_to_keep]
        tokenized_dataset = tokenized_dataset.remove_columns(columns_to_remove)

    output_dataset = None
    if wrap:
        group_texts = partial(_group_texts, block_size=block_size, bos=BOS, eos=EOS)
        if streaming:
            chunked_dataset = tokenized_dataset.map(group_texts, batched=True)
        else:
            chunked_dataset = tokenized_dataset.map(group_texts, batched=True, num_proc=num_proc, load_from_cache_file=use_cache, desc='Grouping')
            chunked_dataset.save_to_disk(output_dir)

        output_dataset = chunked_dataset.with_format('torch')
    else:
        if streaming is False:
            tokenized_dataset.save_to_disk(output_dir)
        output_dataset = tokenized_dataset.with_format('torch')
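
# Example invocation (illustrative; the override keys mirror the config
# attributes read in main, but the script name and values are assumptions):
#
#   python <this_script>.py data.train=openwebtext data.block_size=1024 \
#       data.token_output_dir=/path/to/tokens data.streaming=false data.num_proc=16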


if __name__ == "__main__":
    with breakpoint_on_error():
        main()