"""Tokenizer, dataset and dataloader construction helpers.

Also contains fault-tolerant samplers and a sequence-packing collate function.

NOTE(review): this module was reconstructed from a whitespace-collapsed copy.
Block nesting was inferred from the code's semantics; the spots where more than
one nesting was plausible are marked with ``NOTE(review)`` comments.
"""
import math
import typing
from pathlib import Path

import tokenizers
import torch
import transformers
from unidisc.datasets.sampler import WeightedDatasetSampler
from models.datasets.image_datasets import TensorCollate, get_image_dataset, get_unpaired_dataset
from models.datasets.text_datasets import Text8Tokenizer, get_text_dataset
from torch.utils.data import default_collate
from decoupled_utils import breakpoint_on_error, gprint, rprint, is_torch_xla_available
from datasets import load_dataset
from typing import List, Dict
from tensordict import TensorDict


def identity(x):
    """Return ``x`` unchanged.

    Defined at module level (rather than as a lambda) so it can be pickled when
    it is used as a ``collate_fn`` with multi-process data loading.
    """
    return x


def get_dataset(dataset_name, tokenizer, *args, config=None, **kwargs):
    """Build a dataset, dispatching on flags found in ``config``.

    Dispatch order:
      1. ``config.data.unpaired`` truthy -> ``get_unpaired_dataset``
      2. ``config.model.image_model`` or ``config.data.force_image_dataset``
         truthy -> ``get_image_dataset``
      3. otherwise -> ``get_text_dataset``

    Args:
        dataset_name: Name forwarded to ``get_text_dataset`` (only used by the
            text branch, and for logging).
        tokenizer: Tokenizer forwarded to the selected builder.
        *args: Extra positional arguments; only forwarded to the text builder.
        config: Config object exposing ``.data`` and ``.model``. Although the
            default is ``None``, the body dereferences ``config.data``
            unconditionally, so a config must be supplied.
        **kwargs: Forwarded to the selected builder.

    Returns:
        Whatever the selected builder returns.
    """
    rprint(f"getting dataset {dataset_name}")
    if getattr(config.data, "unpaired", False):
        return get_unpaired_dataset(config=config, tokenizer=tokenizer, **kwargs)
    elif getattr(config.model, "image_model", False) or getattr(config.data, "force_image_dataset", False):
        return get_image_dataset(config=config, tokenizer=tokenizer, **kwargs)
    else:
        rprint(f"getting text dataset")
        return get_text_dataset(dataset_name, tokenizer, *args, **kwargs)


def tokenize_text(tokenizer, block_size, text, return_token_type_ids=True):
    """Tokenize ``text`` to exactly ``block_size`` tokens and convert to PyTorch tensors.

    The tokenizer is called with ``padding="max_length"``, ``truncation=True``
    and ``add_special_tokens=True``; an attention mask is always requested.

    Args:
        tokenizer: Callable tokenizer.
        block_size: Value passed as ``max_length``.
        text: Input passed straight to the tokenizer.
        return_token_type_ids: Forwarded to the tokenizer.

    Returns:
        The tokenizer output after ``.convert_to_tensors("pt")``.
    """
    return tokenizer(
        text,
        max_length=block_size,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=return_token_type_ids,
    ).convert_to_tensors("pt")


def get_tokenizer(config):
    """Create and normalise the tokenizer named by ``config.data.tokenizer_name_or_path``.

    After loading, the tokenizer is guaranteed to have ``bos_token``,
    ``eos_token`` and ``pad_token`` set, and is asserted to pad and truncate on
    the right.

    Args:
        config: Config object exposing ``.data``.

    Returns:
        The tokenizer, or ``None`` when the configured name is ``None`` or the
        string ``"None"``.

    Raises:
        AttributeError: If the tokenizer has neither ``bos_token`` nor
            ``cls_token``, or neither ``eos_token`` nor ``sep_token``.
    """
    name_or_path = config.data.tokenizer_name_or_path
    if name_or_path is None or name_or_path == "None":
        return None
    elif name_or_path == "text8":
        tokenizer = Text8Tokenizer()
    elif name_or_path == "bert-base-uncased":
        tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    else:
        tokenizer_kwargs = dict()
        if name_or_path == "NousResearch/Llama-2-7b-hf":
            tokenizer_kwargs["add_eos_token"] = True
            tokenizer_kwargs["padding_side"] = 'right'
            rprint("Using Llama tokenizer, adding add_eos_token and setting padding_side to right")

        if getattr(config.data, "use_slow_tokenizer", False):
            tokenizer_kwargs["use_fast"] = False

        tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path, **tokenizer_kwargs)

    # NOTE(review): the collapsed source does not show whether this block was
    # nested inside the `else` branch above; it is placed at function level here.
    if getattr(config.data, "add_image_token", False):
        # NOTE(review): the special-token literal is an empty string in the
        # source this file was recovered from. It looks like the text was lost
        # (PackingCollate asserts that tokenizing the same literal yields exactly
        # one id, which an empty string cannot do) - restore the real token.
        special_token = ''
        existing_id = 811
        tmp_index = len(tokenizer)
        tokenizer.add_special_tokens(
            {'additional_special_tokens': [special_token]},
            replace_additional_special_tokens=False,
        )
        # Move the freshly appended token from index `tmp_index` onto the
        # already existing id `existing_id` in the tokenizer's private tables.
        tokenizer._added_tokens_decoder[existing_id] = tokenizer._added_tokens_decoder.pop(tmp_index)
        assert len(tokenizer.additional_special_tokens_ids) == 1
        tokenizer.additional_special_tokens_ids = [existing_id]
        tokenizer._added_tokens_encoder[''] = existing_id
        tokenizer.total_vocab_size = tmp_index

    if isinstance(tokenizer, (transformers.GPT2TokenizerFast, transformers.GPT2Tokenizer)):
        # Wrap every encoded sequence in BOS ... EOS.
        tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(
            (tokenizer.bos_token, tokenizer.bos_token_id),
            (tokenizer.eos_token, tokenizer.eos_token_id),
        )

    # For wrapped batches:
    #  [BOS] sent1 [EOS] sent2-fragment [EOS]
    #  [BOS] sent2-fragment [EOS] sent3 [EOS]
    if tokenizer.bos_token is None:
        if tokenizer.cls_token is None:
            raise AttributeError("Tokenizer must have a bos_token or " f"cls_token: {tokenizer}")
        tokenizer.bos_token = tokenizer.cls_token
    if tokenizer.eos_token is None:
        if tokenizer.sep_token is None:
            raise AttributeError("Tokenizer must have a eos_token " f"or sep_token: {tokenizer}")
        tokenizer.eos_token = tokenizer.sep_token
    if tokenizer.pad_token is None:
        # NOTE(review): only the log line is assumed to be conditional on the
        # tokenizer name; the pad token itself is added for every tokenizer
        # lacking one - confirm against the original layout.
        if name_or_path != "gpt2":
            rprint(f"Adding pad token to tokenizer")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    assert tokenizer.padding_side == 'right'
    assert tokenizer.truncation_side == 'right'

    return tokenizer


class SimpleDataLoader:
    """Minimal single-process, in-order batch iterator over an indexable dataset.

    Batches are built from consecutive indices and passed through
    ``collate_fn``. The final batch may be smaller than ``batch_size``.
    """

    def __init__(self, dataset, batch_size=1, collate_fn=default_collate, **kwargs):
        """Store the dataset and batching parameters.

        Args:
            dataset: Object supporting ``len()`` and integer indexing.
            batch_size: Maximum number of items per batch.
            collate_fn: Callable applied to each list of items.
            **kwargs: Accepted and ignored (keeps the call signature compatible
                with callers that pass extra DataLoader-style options).
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.idx = 0

    def __iter__(self):
        # Restart from the beginning on every new iteration. Previously the
        # cursor was never reset, so a second pass over the loader (e.g. a
        # second epoch) yielded no batches at all.
        self.idx = 0
        return self

    def __next__(self):
        total = len(self.dataset)
        if self.idx >= total:
            raise StopIteration

        stop = min(self.idx + self.batch_size, total)
        batch = [self.dataset[j] for j in range(self.idx, stop)]
        self.idx = stop
        return self.collate_fn(batch)

    def __len__(self):
        # Ceiling division: number of batches including a partial last batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size


def get_zero_shot_dataloader(config, tokenizer, device=None, **kwargs):
    """Build the evaluation dataloader for ``config.data.zero_shot_eval_dataset``.

    Args:
        config: Config object exposing ``.data``, ``.loader``, ``.seed`` and ``.mode``.
        tokenizer: Attached to the returned loader as ``.tokenizer``.
        device: Currently unused.
        **kwargs: Forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        ``(None, None)`` when no zero-shot dataset is configured; otherwise a
        single ``DataLoader`` over the dataset's ``"test"`` split.
        NOTE(review): the two return shapes differ; kept for backward
        compatibility with existing callers.

    Raises:
        ValueError: If the configured dataset name is not supported.
    """
    if config.data.zero_shot_eval_dataset is None:
        rprint("No zero shot eval dataset provided")
        return None, None

    dataset_name = config.data.zero_shot_eval_dataset
    # The configured seed is only honoured in eval mode; otherwise a fixed seed is used.
    dataloader_seed = config.seed if config.mode == "eval" else 42

    if dataset_name == "nlphuji/flickr30k":
        data = load_dataset(
            dataset_name,
            num_proc=config.data.num_proc,
            cache_dir=config.data.cache_dir,
            streaming=config.data.streaming,
        )
        dataset = data["test"]
    elif dataset_name == "facebook/winoground":
        data = load_dataset(
            dataset_name,
            num_proc=config.data.num_proc,
            cache_dir=config.data.cache_dir,
            streaming=config.data.streaming,
        )
        dataset = data["test"]
        # NOTE(review): debugging stop left in the original code. It will block
        # non-interactive runs; remove once the winoground path is finished.
        breakpoint()
    else:
        # Previously `dataset` was left unbound here and the function failed
        # later with an UnboundLocalError.
        raise ValueError(f"Unsupported zero shot eval dataset: {dataset_name}")

    dl_cls = torch.utils.data.DataLoader
    valid_loader = dl_cls(
        dataset,
        batch_size=config.loader.eval_batch_size,
        num_workers=config.loader.num_eval_workers,
        pin_memory=config.loader.pin_memory,
        generator=torch.Generator().manual_seed(dataloader_seed),
        persistent_workers=False,
        **kwargs,
    )

    valid_loader.tokenizer = tokenizer
    return valid_loader


def get_dataloaders(config, tokenizer, skip_train=False, skip_valid=False, valid_seed=None, device=None, **kwargs):
    """Build the training and validation dataloaders described by ``config``.

    Args:
        config: Config object exposing ``.data``, ``.model``, ``.loader``,
            ``.seed``, ``.mode`` and (for GPU pinning) ``.backend``.
        tokenizer: Passed to the dataset builders and attached to each returned
            loader as ``.tokenizer``.
        skip_train: When true, no training dataset/loader is built.
        skip_valid: When true, no validation dataset/loader is built.
        valid_seed: Currently unused.
        device: Device used for the tensordict collate and for optionally moving
            a token dataset onto the accelerator.
        **kwargs: Forwarded to ``get_dataset`` for both splits.

    Returns:
        ``(train_loader, valid_loader)``; each entry is ``None`` when skipped.
    """
    if skip_train:
        train_set = None
    else:
        _mode = getattr(config.data, "force_train_mode", "train")
        if _mode != "train":
            rprint(f"Forcing train mode to {_mode}")
        train_set = get_dataset(
            config.data.train,
            tokenizer,
            mode=_mode,
            wrap=config.data.wrap,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            num_proc=config.data.num_proc,
            streaming=config.data.streaming,
            config=config,
            **kwargs,
        )
        if hasattr(train_set, '__len__'):
            rprint(f"Training set len: {len(train_set)}")

    # These datasets are evaluated on their "test" split.
    if config.data.valid in ["text8", "lm1b", "ag_news"]:
        validation_split = "test"
    else:
        validation_split = "validation"

    if skip_valid:
        valid_set = None
    else:
        valid_set = get_dataset(
            config.data.valid,
            tokenizer,
            wrap=config.data.wrap,
            mode=validation_split,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            streaming=False,
            num_proc=config.data.num_proc,
            config=config,
            **kwargs,
        )
        if hasattr(valid_set, '__len__'):
            rprint(f"Validation set len: {len(valid_set)}")

    use_config_seed = config.mode == "eval" or is_torch_xla_available() or getattr(config.data, "force_seed", False)
    dataloader_seed = config.seed if use_config_seed else 42
    gprint(f"Dataloader seed: {dataloader_seed}")

    if skip_train:
        train_loader = None
    else:
        train_kwargs = dict(drop_last=True)
        train_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        dl_cls = torch.utils.data.DataLoader

        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            train_kwargs.pop("drop_last", None)

        if getattr(config.loader, "disable_prefetch", False):
            train_kwargs["prefetch_factor"] = 1

        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                import webdataset

                dl_cls = webdataset.WebLoader
                train_kwargs["shuffle"] = False
                train_kwargs["prefetch_factor"] = 8
            elif getattr(config.data, "webdataset_indexed", False):
                import wids

                train_kwargs["sampler"] = wids.DistributedChunkedSampler(train_set, shuffle=True)
            elif isinstance(train_set, torch.utils.data.IterableDataset) is False:
                train_kwargs["shuffle"] = True

        if "tokens" in config.data.train and config.data.pin_dataset_to_gpu:
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moving dataloader to device {device} with: {cur_mb} GB of memory reserved")

            train_set = train_set.to(device=device)
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moved dataloader to device {device} with: {cur_mb} GB of memory reserved")

        if "tokens" in config.data.train:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                train_kwargs["collate_fn"] = TensorCollate(
                    device=device,
                    enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate,
                )
            else:
                train_kwargs["collate_fn"] = identity

            # NOTE(review): the packing / weighted-sampler blocks are assumed to
            # be nested under the "tokens" check - an unconditional
            # `shuffle = True` would conflict with the iterable/webdataset
            # settings chosen above. Confirm against the original layout.
            if getattr(config.data, "use_packing_collate", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                token_collate = (
                    train_kwargs["collate_fn"]
                    if getattr(config.data, "use_custom_tensordict_collate", False)
                    else None
                )
                train_kwargs["collate_fn"] = PackingCollate(
                    config, train_set, config.model.length, generator,
                    tensor_collate=token_collate, tokenizer=tokenizer,
                )

            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                train_kwargs['sampler'] = WeightedDatasetSampler(train_set, generator=generator)
                # A sampler and shuffle=True are mutually exclusive in DataLoader.
                train_kwargs["shuffle"] = False
            else:
                train_kwargs["shuffle"] = True

        if getattr(config.data, "use_list_collate", False):
            # Use the module-level `identity` instead of a lambda so the
            # collate function stays picklable for multi-process workers.
            train_kwargs["collate_fn"] = identity

        if getattr(config.data, "force_shuffle_train", False):
            rprint("Forcing shuffle on train dataloader")
            train_kwargs["shuffle"] = True

        if getattr(config.data, "force_disable_shuffle_train", False):
            rprint("Forcing disable shuffle on train dataloader")
            train_kwargs["shuffle"] = False

        if getattr(config.data, "force_distributed_sampler", False):
            import torch_xla.runtime as xr

            train_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                train_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True,
            )

        if getattr(config.data, "use_identity_collate", False):
            # See the note on `use_list_collate` above.
            train_kwargs["collate_fn"] = identity

        if train_set.__class__.__name__ == "WebLoader":
            # The dataset builder already returned a loader; use it as-is.
            train_loader = train_set
        else:
            rprint(f"Train dataloader kwargs: {train_kwargs}")
            train_loader = dl_cls(
                train_set,
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.batch_size,
                num_workers=config.loader.num_workers,
                pin_memory=config.loader.pin_memory,
                persistent_workers=config.loader.num_workers > 0 and getattr(config.loader, "persistent_workers", True),
                generator=train_dataloader_generator,
                **train_kwargs,
            )

        train_loader.tokenizer = tokenizer

    if skip_valid:
        valid_loader = None
    else:
        shuffle_valid = True
        valid_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        valid_kwargs = dict(drop_last=True)
        dl_cls = torch.utils.data.DataLoader

        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            valid_kwargs.pop("drop_last", None)

        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                valid_kwargs["shuffle"] = False
                import webdataset

                dl_cls = webdataset.WebLoader
            elif getattr(config.data, "webdataset_indexed", False):
                import wids

                valid_kwargs["sampler"] = wids.DistributedChunkedSampler(valid_set, shuffle=True)
            elif isinstance(valid_set, torch.utils.data.IterableDataset) is False and shuffle_valid:
                valid_kwargs["shuffle"] = shuffle_valid

        if "tokens" in config.data.valid:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                valid_kwargs["collate_fn"] = TensorCollate(
                    device=device,
                    enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate,
                )
            else:
                valid_kwargs["collate_fn"] = identity

            # NOTE(review): as in the training branch, the following blocks are
            # assumed to be nested under the "tokens" check - confirm.
            if getattr(config.data, "use_packing_collate", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                token_collate = (
                    valid_kwargs["collate_fn"]
                    if getattr(config.data, "use_custom_tensordict_collate", False)
                    else None
                )
                valid_kwargs["collate_fn"] = PackingCollate(
                    config, valid_set, config.model.length, generator,
                    tensor_collate=token_collate, tokenizer=tokenizer,
                )

            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                valid_kwargs['sampler'] = WeightedDatasetSampler(valid_set, generator=generator)

            if getattr(config.data, "shuffle_valid", False):
                torch.manual_seed(config.seed)
            valid_kwargs["shuffle"] = getattr(config.data, "shuffle_valid", False)

        if getattr(config.data, "force_distributed_sampler", False):
            import torch_xla.runtime as xr

            valid_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                valid_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True,
            )

        if valid_set.__class__.__name__ == "WebLoader":
            valid_loader = valid_set
        else:
            rprint(f"Valid dataloader kwargs: {valid_kwargs}")
            valid_loader = dl_cls(
                valid_set,
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.eval_batch_size,
                num_workers=getattr(config.loader, "num_eval_workers", config.loader.num_workers),
                pin_memory=config.loader.pin_memory,
                generator=valid_dataloader_generator,
                persistent_workers=False,
                **valid_kwargs,
            )
        # Will be used in generative perplexity calculation
        valid_loader.tokenizer = tokenizer

    return train_loader, valid_loader


# Samplers adapted from: https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/fault_tolerant_sampler.py


class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    """Random sampler that can checkpoint and resume part-way through an epoch.

    ``counter`` tracks how many indices have been yielded in the current pass;
    after ``load_state_dict`` the next iteration skips that many indices.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed,
        # which should be reproducible if pl.seed_everything was called beforehand.
        # This means that changing the seed of the experiment will also change the
        # sampling order.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        # RandomSampler does not accept `shuffle`; drop it if a caller passed it.
        kwargs.pop("shuffle", None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        """Return the generator's current state and the number of indices yielded so far."""
        # NOTE(review): this stores the generator state at call time, not the
        # state captured in `self.state` before the current permutation was
        # drawn - verify that a mid-epoch resume reproduces the same permutation.
        return {"random_state": self.generator.get_state(), "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore generator state and counter; the next iteration resumes from the counter."""
        self.generator.set_state(state_dict.get("random_state"))
        self.counter = state_dict["counter"]
        # self.start_counter = self.counter
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        # Remember the generator state from before this epoch's permutation is drawn.
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            # Resuming: skip the indices that were already consumed.
            indices = indices[self.counter :]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0


class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    """Distributed sampler that can checkpoint and resume part-way through an epoch.

    The shuffle is derived from ``seed + epoch``, so only the epoch and the
    number of indices already yielded need to be stored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        """Return the current epoch and the number of indices yielded so far."""
        return {"epoch": self.epoch, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore epoch and counter; the next iteration resumes from the counter."""
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per
    # epoch, and subsequent epoch will have very few batches.

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            # Resuming: skip the indices that were already consumed.
            indices = indices[self.counter :]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0


def process_batch(batch: TensorDict):
    """Strip bookkeeping keys from a tensordict (or list of them) and refresh its batch size.

    Removes the ``"write_flag"`` and ``"dataset_idx"`` entries when present and
    calls ``auto_batch_size_()``. The input is modified in place.

    Args:
        batch: A tensordict-like object, or a list of them (handled recursively).

    Returns:
        The same object (or a new list of the same objects).
    """
    if isinstance(batch, list):
        return [process_batch(b) for b in batch]

    if "write_flag" in batch:
        del batch["write_flag"]
    if "dataset_idx" in batch:
        del batch["dataset_idx"]
    batch.auto_batch_size_()
    return batch


def ignore_slice(tensor, slice, padding_token_id):
    """Mark the region ``slice`` of ``tensor`` as padding, in place.

    Sets ``modality`` to -1, ``attention_mask`` to 0, ``input_ids`` to
    ``padding_token_id`` and ``sample_ids`` to -1 over ``slice``. If the
    container has no ``"sample_ids"`` entry yet, one is created filled entirely
    with -1 (matching the shape, dtype and device of ``input_ids``).

    Args:
        tensor: Mapping-like container of tensors (e.g. a tensordict).
        slice: Index expression applied to each entry. (The name shadows the
            builtin but is kept because it is part of the public signature.)
        padding_token_id: Value written into ``input_ids``.
    """
    tensor["modality"][slice] = -1
    tensor["attention_mask"][slice] = 0
    tensor["input_ids"][slice] = padding_token_id
    if "sample_ids" in tensor:
        tensor["sample_ids"][slice] = -1
    else:
        input_ids = tensor["input_ids"]
        tensor["sample_ids"] = torch.full(
            input_ids.shape, fill_value=-1, dtype=input_ids.dtype, device=input_ids.device
        )


class PackingCollate:
    """Collate function that packs several samples into each fixed-length row.

    Every output row starts with the corresponding sample of the incoming batch.
    Unless packing is disabled, remaining space is filled with samples drawn at
    random from ``dataset`` using ``generator``.
    """

    def __init__(self, config, dataset, seq_length, generator, tensor_collate=None, tokenizer=None):
        """Store packing parameters and resolve the special token ids.

        Args:
            config: Config object; ``config.data.disable_packing`` is read.
            dataset: Dataset exposing ``.datasets`` and indexable by
                ``(dataset_idx, element_idx)`` tuples.
            seq_length: Length of each packed output row.
            generator: ``torch.Generator`` used for drawing fill samples.
            tensor_collate: Optional callable applied to the incoming batch first.
            tokenizer: Tokenizer providing ``pad_token_id`` / ``eos_token_id``.
                Required in practice despite the ``None`` default.
        """
        self.dataset = dataset
        self.seq_length = seq_length
        self.tensor_collate = tensor_collate
        self.generator = generator
        self.tokenizer = tokenizer
        self.padding_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.disable_packing = getattr(config.data, "disable_packing", False)
        # NOTE(review): the literal below is empty in the recovered source and
        # appears to have lost its content (an empty string cannot satisfy the
        # single-token assertion). Restore the image special token.
        img_special_tokens = tokenizer("", add_special_tokens=False)['input_ids']
        assert len(img_special_tokens) == 1
        self.image_token_id = img_special_tokens[0]

    def __call__(self, batch: TensorDict):
        """Pack ``batch`` into a new container of shape ``(len(batch), seq_length)``."""
        if self.tensor_collate is not None:
            if isinstance(batch, list):
                batch = [self.tensor_collate(b) for b in batch]
            else:
                batch = self.tensor_collate(batch)

        B = len(batch)
        seq_length = self.seq_length
        batch = process_batch(batch)

        assert batch[0].batch_size is None or len(batch[0].batch_size) == 1
        new_batch = batch[0].new_zeros((B, seq_length))
        # Start from an all-padding batch; packed samples overwrite it below.
        ignore_slice(new_batch, slice(None, None), self.padding_token_id)

        for i in range(B):
            total_length = 0
            sample_idx = 0
            sample_queue = [batch[i]]

            # We originally get bs number of samples but since we're packing, we probably need more so we randomly select.
            while total_length < seq_length:
                if self.disable_packing and sample_idx > 0:
                    break

                if not sample_queue:
                    dataset_idx = torch.randint(len(self.dataset.datasets), (1,), generator=self.generator).item()
                    element_idx = torch.randint(
                        len(self.dataset.datasets[dataset_idx]), (1,), generator=self.generator
                    ).item()
                    sample = self.dataset[(dataset_idx, element_idx)]
                    sample = process_batch(sample)
                else:
                    sample = sample_queue.pop(0)

                available_length = seq_length - total_length
                # Skip a sample when less than a quarter of its length would fit.
                if available_length < sample.shape[0] // 4:
                    if total_length > 0:
                        break
                    else:
                        continue

                if "sample_ids" not in sample:
                    # Positions before the first padding token get 0, positions
                    # from the first padding token onwards get -1.
                    sequence_starts = (sample['input_ids'] == self.padding_token_id).long()
                    sample["sample_ids"] = torch.cumsum(sequence_starts, dim=0) - 1
                    processed_ids = torch.where(
                        sample["sample_ids"] < 0, torch.zeros_like(sample["sample_ids"]), -1
                    )
                    sample["sample_ids"] = processed_ids

                if not ((sample["sample_ids"] == 0) | (sample["sample_ids"] == -1)).all():
                    assert (sample["modality"] == 0).all()

                # Index of the first ignored (-1) position, i.e. the usable length.
                first_neg_one = (sample["sample_ids"] == -1).nonzero(as_tuple=True)[0]
                if first_neg_one.numel() > 0:
                    first_neg_one = first_neg_one[0].item()
                else:
                    assert sample["attention_mask"].all()
                    first_neg_one = len(sample["attention_mask"])

                valid_slice = slice(None, min(first_neg_one, available_length))
                new_length = min(first_neg_one, available_length)
                sample["sample_ids"][valid_slice] = sample_idx
                new_batch[i, total_length:total_length+new_length] = sample[valid_slice]
                total_length += new_length
                sample_idx += 1

            if (new_batch["sample_ids"] == -1).all():
                gprint(f"WARNING!!!! All sample ids are -1 in packing collate before ignore")

            if new_batch["modality"][i, -1] == 1:
                # Find contiguous sequence of image tokens from the end
                modality_slice = new_batch["modality"][i]
                is_image = modality_slice == 1

                # Get indices where modality changes
                change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1

                if change_points.numel() > 0 and is_image[-1]:
                    # Get start of last contiguous image sequence
                    start_pos = change_points[-1].item()
                    assert (new_batch["modality"][i, start_pos:] == 1).all()
                    try:
                        # Also drop the image marker token directly preceding the run.
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] == self.image_token_id:
                            start_pos -= 1

                        # Terminate the preceding text with EOS if it is not already.
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] != self.eos_token_id:
                            new_batch["input_ids"][i, start_pos] = self.eos_token_id
                            new_batch["attention_mask"][i, start_pos] = 1
                            new_batch["modality"][i, start_pos] = 0
                            start_pos += 1
                    except IndexError:
                        print(f"WARNING!!!! ERROR IN PACKING COLLATE")

                    # Mask out the trailing image run as padding.
                    ignore_slice(new_batch[i], slice(start_pos, None), self.padding_token_id)

            # NOTE(review): assumed to sit inside the per-row loop, mirroring the
            # "before ignore" check; it only affects how often the warning prints.
            if (new_batch["sample_ids"] == -1).all():
                gprint(f"WARNING!!!! All sample ids are -1 in packing collate after ignore")

        return new_batch


if __name__ == "__main__":
    # Ad-hoc manual smoke test: builds the facecaption dataset, fetches one
    # sample and drops into the debugger.
    import os

    with breakpoint_on_error():
        from omegaconf import OmegaConf

        cc12m_config = OmegaConf.create(
            {
                "model": {
                    "image_model": True,
                    "unified_model": True,
                },
                "data": {
                    "tokenizers_parallelism": False,
                    "resolution": 128,
                    "train": "pixparse/cc12m-wds",
                    "val": "pixparse/cc12m-wds",
                    "streaming": False,
                    "precache": True,
                    "tokenizer_name_or_path": "gpt2",
                    "n_val_samples": None,
                    "n_train_samples": None,
                    "block_size": 32,
                    "data_dir": "/path/to/cc12m",
                },
            }
        )

        imagenet_config = OmegaConf.create(
            {
                "model": {
                    "image_model": True,
                },
                "data": {
                    "resolution": 128,
                    "train": "ILSVRC/imagenet-1k",
                    "val": "ILSVRC/imagenet-1k",
                    "streaming": False,
                    "precache": True,
                    "tokenizer_name_or_path": "gpt2",
                },
            }
        )

        facecaption_config = OmegaConf.create(
            {
                "seed": 12345,
                "model": {
                    "image_model": True,
                },
                "data": {
                    "resolution": 256,
                    "train": "facecaption",
                    "val": "facecaption",
                    "streaming": False,
                    "precache": False,
                    "tokenizer_name_or_path": "gpt2",
                    "cache_dir": os.environ["HF_DATASETS_CACHE"],
                    "raw_data_dir": "/grogu/user/mprabhud/data/diffusion/facecaption",
                    "block_size": 32,
                },
                "loader": {
                    "num_workers": 0,
                    "batch_size": 1,
                    "eval_batch_size": 1,
                },
                "trainer": {
                    "devices": 1,
                    "num_nodes": 1,
                    "accumulate_grad_batches": 1,
                },
            }
        )

        tokenizer = get_tokenizer(facecaption_config)
        dataset = get_dataset(
            dataset_name=facecaption_config.data.train,
            mode="train",
            config=facecaption_config,
            tokenizer=tokenizer,
        )
        test = next(iter(dataset))
        breakpoint()