# unidisc/dataloader.py
# (Originally from the unidisc repository, initial commit 131da64 by aswerdlow.)
import math
import typing
from pathlib import Path
import tokenizers
import torch
import transformers
from unidisc.datasets.sampler import WeightedDatasetSampler
from models.datasets.image_datasets import TensorCollate, get_image_dataset, get_unpaired_dataset
from models.datasets.text_datasets import Text8Tokenizer, get_text_dataset
from torch.utils.data import default_collate
from decoupled_utils import breakpoint_on_error, gprint, rprint, is_torch_xla_available
from datasets import load_dataset
def identity(x):
    """Return the argument unchanged (used below as a pass-through collate_fn)."""
    return x
def get_dataset(dataset_name, tokenizer, *args, config=None, **kwargs):
    """Dispatch to the unpaired / image / text dataset builders based on config flags."""
    rprint(f"getting dataset {dataset_name}")
    # Unpaired data takes precedence over everything else.
    if getattr(config.data, "unpaired", False):
        return get_unpaired_dataset(config=config, tokenizer=tokenizer, **kwargs)
    # Image datasets: either the model is an image model, or the config forces it.
    if getattr(config.model, "image_model", False) or getattr(config.data, "force_image_dataset", False):
        return get_image_dataset(config=config, tokenizer=tokenizer, **kwargs)
    # Default: plain text dataset.
    rprint(f"getting text dataset")
    return get_text_dataset(dataset_name, tokenizer, *args, **kwargs)
def tokenize_text(tokenizer, block_size, text, return_token_type_ids=True):
    """Tokenize `text` into fixed-length (`block_size`) padded/truncated PyTorch tensors."""
    encoding = tokenizer(
        text,
        max_length=block_size,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=return_token_type_ids,
    )
    return encoding.convert_to_tensors("pt")
def get_tokenizer(config):
    """Build and normalize the tokenizer named by ``config.data.tokenizer_name_or_path``.

    Returns ``None`` when no tokenizer is configured. Otherwise ensures the
    returned tokenizer has bos/eos tokens (falling back to cls/sep), adds a
    pad token for non-gpt2 tokenizers, and asserts right-side padding and
    truncation before returning.
    """
    if config.data.tokenizer_name_or_path is None or config.data.tokenizer_name_or_path == "None":
        return None
    elif config.data.tokenizer_name_or_path == "text8":
        # Character-level tokenizer for the text8 corpus.
        tokenizer = Text8Tokenizer()
    elif config.data.tokenizer_name_or_path == "bert-base-uncased":
        tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    else:
        tokenizer_kwargs = dict()
        if config.data.tokenizer_name_or_path == "NousResearch/Llama-2-7b-hf":
            # Llama does not append EOS by default; force it and pad on the right.
            tokenizer_kwargs["add_eos_token"] = True
            tokenizer_kwargs["padding_side"] = 'right'
            rprint("Using Llama tokenizer, adding add_eos_token and setting padding_side to right")
        if getattr(config.data, "use_slow_tokenizer", False):
            tokenizer_kwargs["use_fast"] = False
        tokenizer = transformers.AutoTokenizer.from_pretrained(config.data.tokenizer_name_or_path, **tokenizer_kwargs)
        if getattr(config.data, "add_image_token", False):
            # Register '<image>' as a special token, but remap it onto the
            # pre-existing vocab id 811 instead of growing the vocabulary.
            # NOTE(review): this pokes private HF internals
            # (_added_tokens_decoder/_added_tokens_encoder) — verify after any
            # transformers upgrade. Presumably id 811 is otherwise unused in
            # this vocabulary — TODO confirm.
            special_token = '<image>'
            existing_id = 811
            tmp_index = len(tokenizer)  # id the new token gets appended at
            tokenizer.add_special_tokens({
                'additional_special_tokens': [special_token]
            }, replace_additional_special_tokens=False)
            # Move the freshly appended token from tmp_index back to 811.
            tokenizer._added_tokens_decoder[existing_id] = tokenizer._added_tokens_decoder.pop(tmp_index)
            assert len(tokenizer.additional_special_tokens_ids) == 1
            tokenizer.additional_special_tokens_ids = [existing_id]
            tokenizer._added_tokens_encoder['<image>'] = existing_id
            tokenizer.total_vocab_size = tmp_index
    if isinstance(tokenizer, transformers.GPT2TokenizerFast) or isinstance(tokenizer, transformers.GPT2Tokenizer):
        # GPT-2 adds neither BOS nor EOS by default; wrap output as [BOS] ... [EOS].
        tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
            (tokenizer.bos_token, tokenizer.bos_token_id), (tokenizer.eos_token, tokenizer.eos_token_id)
        )
    # For wrapped batches:
    # [BOS] sent1 [EOS] sent2-fragment [EOS]
    # [BOS] sent2-fragment [EOS] sent3 [EOS]
    if tokenizer.bos_token is None:
        if tokenizer.cls_token is None:
            raise AttributeError("Tokenizer must have a bos_token or " f"cls_token: {tokenizer}")
        # Fall back to CLS as the beginning-of-sequence marker.
        tokenizer.bos_token = tokenizer.cls_token
    if tokenizer.eos_token is None:
        if tokenizer.sep_token is None:
            raise AttributeError("Tokenizer must have a eos_token " f"or sep_token: {tokenizer}")
        # Fall back to SEP as the end-of-sequence marker.
        tokenizer.eos_token = tokenizer.sep_token
    if tokenizer.pad_token is None:
        # gpt2 is deliberately skipped here — presumably its padding is handled
        # by the caller; TODO confirm.
        if config.data.tokenizer_name_or_path != "gpt2":
            rprint(f"Adding pad token to tokenizer")
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    assert tokenizer.padding_side == 'right'
    assert tokenizer.truncation_side == 'right'
    return tokenizer
class SimpleDataLoader:
    """Minimal sequential batch iterator over an indexable dataset.

    Walks the dataset front-to-back in ``batch_size`` chunks (the final batch
    may be smaller) and applies ``collate_fn`` to each list of samples.

    Bug fixed vs. the original: ``__iter__`` now resets the internal cursor,
    so the loader can be iterated more than once. Previously the cursor was
    only set in ``__init__``, making every iteration after the first yield
    nothing.
    """

    def __init__(self, dataset, batch_size=1, collate_fn=default_collate, **kwargs):
        # dataset must support __len__ and integer __getitem__.
        # Extra kwargs are accepted for DataLoader signature compatibility
        # and intentionally ignored.
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.idx = 0  # cursor into the dataset

    def __iter__(self):
        # Reset the cursor so the loader is re-iterable.
        self.idx = 0
        return self

    def __next__(self):
        # Guard clause: exhausted.
        if self.idx >= len(self.dataset):
            raise StopIteration
        batch = []
        for _ in range(self.batch_size):
            if self.idx >= len(self.dataset):
                break  # final partial batch
            batch.append(self.dataset[self.idx])
            self.idx += 1
        return self.collate_fn(batch)

    def __len__(self):
        # Number of batches, counting a trailing partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
def get_zero_shot_dataloader(config, tokenizer, device=None, **kwargs):
    """Build an eval DataLoader for the configured zero-shot evaluation dataset.

    Supports the "test" split of nlphuji/flickr30k and facebook/winoground.
    Returns the loader with ``tokenizer`` attached; raises ``ValueError`` for an
    unsupported dataset name (the original silently fell through and crashed
    later with an UnboundLocalError). Also removes a stray debug ``breakpoint()``
    that was left in the winoground branch.
    """
    if config.data.zero_shot_eval_dataset is None:
        rprint("No zero shot eval dataset provided")
        # NOTE(review): this early exit returns a 2-tuple while the success
        # path returns a single loader — kept as-is for compatibility, but
        # confirm which shape callers actually expect.
        return None, None
    dataset_name = config.data.zero_shot_eval_dataset
    # Deterministic shuffling for eval runs; fixed seed otherwise.
    dataloader_seed = config.seed if config.mode == "eval" else 42
    if dataset_name in ("nlphuji/flickr30k", "facebook/winoground"):
        # Both datasets ship only a "test" split.
        data = load_dataset(dataset_name, num_proc=config.data.num_proc, cache_dir=config.data.cache_dir, streaming=config.data.streaming)
        dataset = data["test"]
    else:
        raise ValueError(f"Unsupported zero shot eval dataset: {dataset_name}")
    dl_cls = torch.utils.data.DataLoader
    valid_loader = dl_cls(
        dataset,
        batch_size=config.loader.eval_batch_size,
        num_workers=config.loader.num_eval_workers,
        pin_memory=config.loader.pin_memory,
        generator=torch.Generator().manual_seed(dataloader_seed),
        persistent_workers=False,
        **kwargs,
    )
    # Attach the tokenizer so downstream eval code can decode/encode.
    valid_loader.tokenizer = tokenizer
    return valid_loader
def get_dataloaders(config, tokenizer, skip_train=False, skip_valid=False, valid_seed=None, device=None, **kwargs):
    """Build the (train_loader, valid_loader) pair described by ``config``.

    Supports plain map-style datasets, iterable/indexed webdatasets, token
    (TensorDict) datasets with custom collates, weighted sampling, sequence
    packing, and XLA distributed sampling. Either loader is ``None`` when the
    corresponding ``skip_*`` flag is set.
    NOTE(review): ``valid_seed`` is never referenced in this body — confirm
    whether it was meant to seed the validation generator.
    """
    # --- Dataset construction -------------------------------------------------
    if skip_train:
        train_set = None
    else:
        # Allow overriding the split used for "training" (e.g. for debugging).
        _mode = getattr(config.data, "force_train_mode", "train")
        if _mode != "train":
            rprint(f"Forcing train mode to {_mode}")
        train_set = get_dataset(
            config.data.train,
            tokenizer,
            mode=_mode,
            wrap=config.data.wrap,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            num_proc=config.data.num_proc,
            streaming=config.data.streaming,
            config=config,
            **kwargs,
        )
        if hasattr(train_set, '__len__'):
            rprint(f"Training set len: {len(train_set)}")
    # These corpora ship no "validation" split; fall back to "test".
    if config.data.valid in ["text8", "lm1b", "ag_news"]:
        validation_split = "test"
    else:
        validation_split = "validation"
    if skip_valid:
        valid_set = None
    else:
        valid_set = get_dataset(
            config.data.valid,
            tokenizer,
            wrap=config.data.wrap,
            mode=validation_split,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            streaming=False,
            num_proc=config.data.num_proc,
            config=config,
            **kwargs,
        )
        if hasattr(valid_set, '__len__'):
            rprint(f"Validation set len: {len(valid_set)}")
    # Use the experiment seed for eval / XLA / forced runs, else a fixed 42.
    dataloader_seed = config.seed if (config.mode == "eval" or is_torch_xla_available() or getattr(config.data, "force_seed", False)) else 42
    gprint(f"Dataloader seed: {dataloader_seed}")
    # --- Train loader ---------------------------------------------------------
    if skip_train:
        train_loader = None
    else:
        train_kwargs = dict(drop_last=True)
        train_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        dl_cls = torch.utils.data.DataLoader
        # Webdatasets manage batching/sharding themselves; drop_last is invalid there.
        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            train_kwargs.pop("drop_last", None)
        if getattr(config.loader, "disable_prefetch", False):
            train_kwargs["prefetch_factor"] = 1
        # Shuffle selection: WebLoader handles its own ordering; indexed
        # webdatasets use a distributed chunked sampler; iterable datasets
        # cannot be shuffled by the DataLoader.
        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                import webdataset
                dl_cls = webdataset.WebLoader
                train_kwargs["shuffle"] = False
                train_kwargs["prefetch_factor"] = 8
            elif getattr(config.data, "webdataset_indexed", False):
                import wids
                train_kwargs["sampler"] = wids.DistributedChunkedSampler(train_set, shuffle=True)
            elif isinstance(train_set, torch.utils.data.IterableDataset) is False:
                train_kwargs["shuffle"] = True
        # Optionally pin a pre-tokenized dataset directly on the accelerator.
        if "tokens" in config.data.train and config.data.pin_dataset_to_gpu:
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moving dataloader to device {device} with: {cur_mb} GB of memory reserved")
            train_set = train_set.to(device=device)
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moved dataloader to device {device} with: {cur_mb} GB of memory reserved")
        # Token (TensorDict) datasets need special collation/sampling.
        if "tokens" in config.data.train:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                train_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
            else:
                train_kwargs["collate_fn"] = identity
            if getattr(config.data, "use_packing_collate", False):
                # Packing wraps (and may call through to) the tensor collate.
                generator = torch.Generator().manual_seed(dataloader_seed)
                token_collate = train_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
                train_kwargs["collate_fn"] = PackingCollate(config, train_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)
            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                # Sampler and shuffle are mutually exclusive in DataLoader.
                generator = torch.Generator().manual_seed(dataloader_seed)
                train_kwargs['sampler'] = WeightedDatasetSampler(train_set, generator=generator)
                train_kwargs["shuffle"] = False
            else:
                train_kwargs["shuffle"] = True
        if getattr(config.data, "use_list_collate", False):
            train_kwargs["collate_fn"] = lambda x: x
        if getattr(config.data, "force_shuffle_train", False):
            rprint("Forcing shuffle on train dataloader")
            train_kwargs["shuffle"] = True
        if getattr(config.data, "force_disable_shuffle_train", False):
            rprint("Forcing disable shuffle on train dataloader")
            train_kwargs["shuffle"] = False
        if getattr(config.data, "force_distributed_sampler", False):
            # XLA (TPU) path: shard via a DistributedSampler sized by the XLA runtime.
            import torch_xla.runtime as xr
            train_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                train_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True
            )
        if getattr(config.data, "use_identity_collate", False):
            train_kwargs["collate_fn"] = lambda x: x
        # A WebLoader dataset is already a loader — use it directly.
        if train_set.__class__.__name__ == "WebLoader":
            train_loader = train_set
        else:
            rprint(f"Train dataloader kwargs: {train_kwargs}")
            train_loader = dl_cls(
                train_set,
                # WebLoader-style iterables batch internally (batch_size=None).
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.batch_size,
                num_workers=config.loader.num_workers,
                pin_memory=config.loader.pin_memory,
                persistent_workers=config.loader.num_workers > 0 and getattr(config.loader, "persistent_workers", True),
                generator=train_dataloader_generator,
                **train_kwargs,
            )
        train_loader.tokenizer = tokenizer
    # --- Valid loader ---------------------------------------------------------
    if skip_valid:
        valid_loader = None
    else:
        shuffle_valid = True
        valid_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        valid_kwargs = dict(drop_last=True)
        dl_cls = torch.utils.data.DataLoader
        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            valid_kwargs.pop("drop_last", None)
        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                valid_kwargs["shuffle"] = False
                import webdataset
                dl_cls = webdataset.WebLoader
            elif getattr(config.data, "webdataset_indexed", False):
                import wids
                valid_kwargs["sampler"] = wids.DistributedChunkedSampler(valid_set, shuffle=True)
            elif isinstance(valid_set, torch.utils.data.IterableDataset) is False and shuffle_valid:
                valid_kwargs["shuffle"] = shuffle_valid
        # Token (TensorDict) validation datasets mirror the train-side setup.
        if "tokens" in config.data.valid:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                valid_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
            else:
                valid_kwargs["collate_fn"] = identity
            if getattr(config.data, "use_packing_collate", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                token_collate = valid_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
                valid_kwargs["collate_fn"] = PackingCollate(config, valid_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)
            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                valid_kwargs['sampler'] = WeightedDatasetSampler(valid_set, generator=generator)
        if getattr(config.data, "shuffle_valid", False):
            torch.manual_seed(config.seed)
            valid_kwargs["shuffle"] = getattr(config.data, "shuffle_valid", False)
        if getattr(config.data, "force_distributed_sampler", False):
            import torch_xla.runtime as xr
            valid_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                valid_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True
            )
        if valid_set.__class__.__name__ == "WebLoader":
            valid_loader = valid_set
        else:
            rprint(f"Valid dataloader kwargs: {valid_kwargs}")
            valid_loader = dl_cls(
                valid_set,
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.eval_batch_size,
                num_workers=getattr(config.loader, "num_eval_workers", config.loader.num_workers),
                pin_memory=config.loader.pin_memory,
                generator=valid_dataloader_generator,
                persistent_workers=False,
                **valid_kwargs,
            )
        # Will be used in generative perplexity calculation
        valid_loader.tokenizer = tokenizer
    return train_loader, valid_loader
# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    """RandomSampler that can checkpoint and resume mid-epoch.

    ``state_dict``/``load_state_dict`` capture the generator's RNG state plus
    how many indices have been consumed, so a restarted job regenerates the
    same permutation and skips what it already yielded.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            # Draw a random 64-bit seed and build a dedicated generator from it.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        # RandomSampler does not accept `shuffle`; drop it if a caller passed one.
        kwargs.pop("shuffle", None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0  # indices yielded so far in the current epoch
        self.restarting = False  # set right after load_state_dict()

    def state_dict(self):
        """Snapshot the RNG state and epoch position for fault-tolerant resume."""
        return {"random_state": self.generator.get_state(), "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore RNG state and position; the next __iter__ skips consumed indices."""
        self.generator.set_state(state_dict.get("random_state"))
        self.counter = state_dict["counter"]
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.
    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)
        # Stash the pre-permutation RNG state (checkpointable via state_dict).
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()
        if not self.restarting:
            self.counter = 0
        else:
            # Resume: drop the indices consumed before the restart.
            indices = indices[self.counter :]
            self.restarting = False
        for index in indices:
            self.counter += 1
            yield index
        self.counter = 0
class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    """DistributedSampler with mid-epoch checkpoint/resume support.

    The shuffled order is fully determined by ``seed + epoch``, so only the
    epoch and the number of consumed indices need to be checkpointed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0  # indices yielded so far in the current epoch
        self.restarting = False  # set right after load_state_dict()

    def state_dict(self):
        """Snapshot epoch and position (the permutation derives from seed+epoch)."""
        return {"epoch": self.epoch, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore epoch/position; the next __iter__ skips consumed indices."""
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.
    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                # Dataset smaller than the padding needed: repeat it wholesale.
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples
        if not self.restarting:
            self.counter = 0
        else:
            # Resume mid-epoch: drop the indices consumed before the restart.
            indices = indices[self.counter :]
            self.restarting = False
        for index in indices:
            self.counter += 1
            yield index
        self.counter = 0
if __name__ == "__main__":
    # Ad-hoc smoke test: build a tokenizer and dataset from a hard-coded config
    # and pull one sample. Drops into the debugger on error and at the end.
    import os
    with breakpoint_on_error():
        from omegaconf import OmegaConf
        # Example configs for several datasets; only `facecaption_config` is
        # actually exercised below — the others are kept for convenience.
        cc12m_config = OmegaConf.create(
            {
                "model": {
                    "image_model": True,
                    "unified_model": True,
                },
                "data": {
                    "tokenizers_parallelism": False,
                    "resolution": 128,
                    "train": "pixparse/cc12m-wds",
                    "val": "pixparse/cc12m-wds",
                    "streaming": False,
                    "precache": True,
                    "tokenizer_name_or_path": "gpt2",
                    "n_val_samples": None,
                    "n_train_samples": None,
                    "block_size": 32,
                    "data_dir": "/path/to/cc12m",
                },
            }
        )
        imagenet_config = OmegaConf.create(
            {
                "model": {
                    "image_model": True,
                },
                "data": {
                    "resolution": 128,
                    "train": "ILSVRC/imagenet-1k",
                    "val": "ILSVRC/imagenet-1k",
                    "streaming": False,
                    "precache": True,
                    "tokenizer_name_or_path": "gpt2",
                },
            }
        )
        facecaption_config = OmegaConf.create(
            {
                "seed": 12345,
                "model": {
                    "image_model": True,
                },
                "data": {
                    "resolution": 256,
                    "train": "facecaption",
                    "val": "facecaption",
                    "streaming": False,
                    "precache": False,
                    "tokenizer_name_or_path": "gpt2",
                    "cache_dir": os.environ["HF_DATASETS_CACHE"],
                    "raw_data_dir": "/grogu/user/mprabhud/data/diffusion/facecaption",
                    "block_size": 32,
                },
                "loader": {
                    "num_workers": 0,
                    "batch_size": 1,
                    "eval_batch_size": 1,
                },
                "trainer": {
                    "devices": 1,
                    "num_nodes": 1,
                    "accumulate_grad_batches": 1,
                },
            }
        )
        tokenizer = get_tokenizer(facecaption_config)
        dataset = get_dataset(
            dataset_name=facecaption_config.data.train,
            mode="train",
            config=facecaption_config,
            tokenizer=tokenizer,
        )
        # Fetch a single sample, then drop into the debugger to inspect it.
        test = next(iter(dataset))
        breakpoint()
from typing import List, Dict
import torch
from tensordict import TensorDict
def process_batch(batch: "TensorDict"):
    """Strip bookkeeping keys from a batch and infer its batch size in place.

    Lists are handled by recursing into each element. For a single batch the
    "write_flag" / "dataset_idx" entries (if present) are removed and
    ``auto_batch_size_()`` is invoked before returning the same object.
    """
    # Guard clause: a list of batches is processed element-wise.
    if isinstance(batch, list):
        return [process_batch(item) for item in batch]
    for key in ("write_flag", "dataset_idx"):
        if key in batch:
            del batch[key]
    batch.auto_batch_size_()
    return batch
def ignore_slice(tensor, slice, padding_token_id):
    """Mask out the given positions of a packed example in place.

    Over `slice`: input_ids become the pad token, attention_mask goes to 0 and
    modality to -1. sample_ids over the slice are set to -1 as well; when the
    field is missing entirely, a full-size sample_ids tensor of -1 (matching
    input_ids' dtype/device) is created instead.
    Note: the parameter name `slice` shadows the builtin; kept unchanged for
    interface compatibility.
    """
    tensor["input_ids"][slice] = padding_token_id
    tensor["attention_mask"][slice] = 0
    tensor["modality"][slice] = -1
    if "sample_ids" not in tensor:
        ref = tensor["input_ids"]
        tensor["sample_ids"] = torch.full(ref.shape, fill_value=-1, dtype=ref.dtype, device=ref.device)
    else:
        tensor["sample_ids"][slice] = -1
class PackingCollate:
    """Collate that packs multiple (possibly multimodal) samples into each row.

    Each output row of length ``seq_length`` is filled by concatenating the
    row's original sample followed by randomly drawn extra samples, tagging
    positions with per-sample ``sample_ids``. Unfilled positions are masked
    via ``ignore_slice``. A trailing partially-fitting image run is trimmed
    so rows never end with a truncated image.
    """

    def __init__(self, config, dataset, seq_length, generator, tensor_collate=None, tokenizer=None):
        # dataset: used to draw extra samples; indexed as (dataset_idx, element_idx),
        # so it is expected to expose a `.datasets` list — TODO confirm against caller.
        self.dataset = dataset
        self.seq_length = seq_length
        self.tensor_collate = tensor_collate  # optional pre-collate applied to raw batches
        self.generator = generator  # RNG for drawing filler samples
        self.tokenizer = tokenizer
        self.padding_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        # When set, each row keeps only its original sample (no packing).
        self.disable_packing = getattr(config.data, "disable_packing", False)
        # '<image>' must map to exactly one token id.
        img_special_tokens = tokenizer("<image>", add_special_tokens=False)['input_ids']
        assert len(img_special_tokens) == 1
        self.image_token_id = img_special_tokens[0]

    def __call__(self, batch: TensorDict):
        """Pack `batch` (list or TensorDict of samples) into a (B, seq_length) TensorDict."""
        # Optional pre-collate (e.g. TensorCollate) before packing.
        if self.tensor_collate is not None:
            if isinstance(batch, list):
                batch = [self.tensor_collate(b) for b in batch]
            else:
                batch = self.tensor_collate(batch)
        B = len(batch)
        seq_length = self.seq_length
        batch = process_batch(batch)
        # Samples must be 1-D (a single token sequence each).
        assert batch[0].batch_size is None or len(batch[0].batch_size) == 1
        # Start from an all-"ignored" output buffer.
        new_batch = batch[0].new_zeros((B, seq_length))
        ignore_slice(new_batch, slice(None, None), self.padding_token_id)
        for i in range(B):
            total_length = 0
            sample_idx = 0  # running id assigned to each packed sample in this row
            sample_queue = [batch[i]]
            # We originally get bs number of samples but since we're packing, we probably need more so we randomly select.
            while total_length < seq_length:
                if self.disable_packing and sample_idx > 0:
                    break
                if not sample_queue:
                    # Draw a random filler sample: first pick a sub-dataset, then an element.
                    dataset_idx = torch.randint(len(self.dataset.datasets), (1,), generator=self.generator).item()
                    element_idx = torch.randint(len(self.dataset.datasets[dataset_idx]), (1,), generator=self.generator).item()
                    sample = self.dataset[(dataset_idx, element_idx)]
                    sample = process_batch(sample)
                else:
                    sample = sample_queue.pop(0)
                available_length = seq_length - total_length
                # Skip samples that would be cut to less than ~1/4 of their length;
                # if the row is still empty, draw another sample instead of stopping.
                if available_length < sample.shape[0] // 4:
                    if total_length > 0:
                        break
                    else:
                        continue
                if "sample_ids" not in sample:
                    # Derive sample_ids from padding: positions before the first pad
                    # get 0 (valid), everything from the first pad onward gets -1.
                    sequence_starts = (sample['input_ids'] == self.padding_token_id).long()
                    sample["sample_ids"] = torch.cumsum(sequence_starts, dim=0) - 1
                    processed_ids = torch.where(sample["sample_ids"] < 0, torch.zeros_like(sample["sample_ids"]), -1)
                    sample["sample_ids"] = processed_ids
                # Multi-segment sample_ids are only expected for pure-text samples.
                if not ((sample["sample_ids"] == 0) | (sample["sample_ids"] == -1)).all():
                    assert (sample["modality"] == 0).all()
                # Usable prefix ends at the first "ignored" (-1) position.
                first_neg_one = (sample["sample_ids"] == -1).nonzero(as_tuple=True)[0]
                if first_neg_one.numel() > 0:
                    first_neg_one = first_neg_one[0].item()
                else:
                    # No ignored positions: the whole sample must be attended.
                    assert sample["attention_mask"].all()
                    first_neg_one = len(sample["attention_mask"])
                valid_slice = slice(None, min(first_neg_one, available_length))
                new_length = min(first_neg_one, available_length)
                # Tag the copied span with this row-local sample index.
                sample["sample_ids"][valid_slice] = sample_idx
                new_batch[i, total_length:total_length+new_length] = sample[valid_slice]
                total_length += new_length
                sample_idx += 1
            if (new_batch["sample_ids"] == -1).all():
                gprint(f"WARNING!!!! All sample ids are -1 in packing collate before ignore")
            # If the row ends mid-image, trim the trailing image run so rows
            # never end with a truncated image.
            if new_batch["modality"][i, -1] == 1:
                # Find contiguous sequence of image tokens from the end
                modality_slice = new_batch["modality"][i]
                is_image = modality_slice == 1
                # Get indices where modality changes
                change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1
                if change_points.numel() > 0 and is_image[-1]:
                    # Get start of last contiguous image sequence
                    start_pos = change_points[-1].item()
                    assert (new_batch["modality"][i, start_pos:] == 1).all()
                    try:
                        # Also drop a preceding '<image>' marker token, and make
                        # sure the kept prefix ends with EOS.
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] == self.image_token_id:
                            start_pos -= 1
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] != self.eos_token_id:
                            new_batch["input_ids"][i, start_pos] = self.eos_token_id
                            new_batch["attention_mask"][i, start_pos] = 1
                            new_batch["modality"][i, start_pos] = 0
                            start_pos += 1
                    except IndexError:
                        print(f"WARNING!!!! ERROR IN PACKING COLLATE")
                    # Mask out everything from the trimmed image run onward.
                    ignore_slice(new_batch[i], slice(start_pos, None), self.padding_token_id)
        if (new_batch["sample_ids"] == -1).all():
            gprint(f"WARNING!!!! All sample ids are -1 in packing collate after ignore")
        return new_batch