|
import math |
|
import typing |
|
from pathlib import Path |
|
|
|
import tokenizers |
|
import torch |
|
import transformers |
|
from unidisc.datasets.sampler import WeightedDatasetSampler |
|
|
|
from models.datasets.image_datasets import TensorCollate, get_image_dataset, get_unpaired_dataset |
|
from models.datasets.text_datasets import Text8Tokenizer, get_text_dataset |
|
from torch.utils.data import default_collate |
|
from decoupled_utils import breakpoint_on_error, gprint, rprint, is_torch_xla_available |
|
from datasets import load_dataset |
|
|
|
|
|
def identity(x):
    """Return ``x`` unchanged.

    Used in this file as a pass-through ``collate_fn`` for dataloaders.
    """
    return x
|
|
|
|
|
def get_dataset(dataset_name, tokenizer, *args, config=None, **kwargs):
    """Dispatch to the unpaired, image, or text dataset builder.

    The builder is chosen from ``config``:

    * ``config.data.unpaired`` truthy -> ``get_unpaired_dataset``
    * ``config.model.image_model`` or ``config.data.force_image_dataset``
      truthy -> ``get_image_dataset``
    * otherwise -> ``get_text_dataset``

    Only the text builder receives ``dataset_name`` and the positional
    ``*args``; the other two are called with ``config``, ``tokenizer`` and
    ``**kwargs`` only.

    NOTE(review): ``config`` defaults to ``None`` but is dereferenced
    unconditionally, so callers must always pass it.
    """
    rprint(f"getting dataset {dataset_name}")

    if getattr(config.data, "unpaired", False):
        return get_unpaired_dataset(config=config, tokenizer=tokenizer, **kwargs)

    wants_image_dataset = (
        getattr(config.model, "image_model", False)
        or getattr(config.data, "force_image_dataset", False)
    )
    if wants_image_dataset:
        return get_image_dataset(config=config, tokenizer=tokenizer, **kwargs)

    rprint("getting text dataset")
    return get_text_dataset(dataset_name, tokenizer, *args, **kwargs)
|
|
|
def tokenize_text(tokenizer, block_size, text, return_token_type_ids=True):
    """Tokenize ``text`` to exactly ``block_size`` tokens and convert to "pt" tensors.

    The tokenizer is called with max-length padding, truncation, special
    tokens and an attention mask enabled; the result's
    ``convert_to_tensors("pt")`` is returned.

    Args:
        tokenizer: Callable tokenizer whose result exposes ``convert_to_tensors``.
        block_size: Value passed as ``max_length``.
        text: Input forwarded to the tokenizer as its first argument.
        return_token_type_ids: Forwarded to the tokenizer unchanged.
    """
    encoding = tokenizer(
        text,
        max_length=block_size,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_attention_mask=True,
        return_token_type_ids=return_token_type_ids,
    )
    return encoding.convert_to_tensors("pt")
|
|
|
def get_tokenizer(config):
    """Build and normalise the tokenizer named by ``config.data.tokenizer_name_or_path``.

    Returns ``None`` when the name is ``None`` or the string ``"None"``.
    Otherwise the tokenizer is loaded and then post-processed:

    * optionally registers an ``<image>`` special token
      (``config.data.add_image_token``),
    * installs a BOS/EOS-wrapping post-processor on GPT-2 tokenizers,
    * falls back to CLS/SEP when BOS/EOS are missing,
    * adds a ``[PAD]`` token when no pad token exists.

    Raises:
        AttributeError: if the tokenizer has neither bos/cls or neither
            eos/sep tokens.
        AssertionError: if padding or truncation side is not ``'right'``.
    """
    name_or_path = config.data.tokenizer_name_or_path
    if name_or_path is None or name_or_path == "None":
        return None

    if name_or_path == "text8":
        tokenizer = Text8Tokenizer()
    elif name_or_path == "bert-base-uncased":
        tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    else:
        load_kwargs = {}
        if name_or_path == "NousResearch/Llama-2-7b-hf":
            load_kwargs["add_eos_token"] = True
            load_kwargs["padding_side"] = 'right'
            rprint("Using Llama tokenizer, adding add_eos_token and setting padding_side to right")
        if getattr(config.data, "use_slow_tokenizer", False):
            load_kwargs["use_fast"] = False
        tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path, **load_kwargs)

    if getattr(config.data, "add_image_token", False):
        image_token = '<image>'
        target_id = 811
        # The new token is appended at index len(tokenizer); it is then
        # re-keyed to ``target_id`` in the tokenizer's private added-token
        # tables, and the pre-insertion size is recorded as total_vocab_size.
        appended_id = len(tokenizer)
        tokenizer.add_special_tokens(
            {'additional_special_tokens': [image_token]},
            replace_additional_special_tokens=False,
        )
        tokenizer._added_tokens_decoder[target_id] = tokenizer._added_tokens_decoder.pop(appended_id)
        assert len(tokenizer.additional_special_tokens_ids) == 1
        tokenizer.additional_special_tokens_ids = [target_id]
        tokenizer._added_tokens_encoder[image_token] = target_id
        tokenizer.total_vocab_size = appended_id

    if isinstance(tokenizer, (transformers.GPT2TokenizerFast, transformers.GPT2Tokenizer)):
        # Wrap encoded sequences with the tokenizer's BOS / EOS tokens.
        bos_pair = (tokenizer.bos_token, tokenizer.bos_token_id)
        eos_pair = (tokenizer.eos_token, tokenizer.eos_token_id)
        tokenizer._tokenizer.post_processor = tokenizers.processors.BertProcessing(bos_pair, eos_pair)

    # Guarantee BOS and EOS exist, borrowing CLS / SEP when necessary.
    if tokenizer.bos_token is None:
        if tokenizer.cls_token is None:
            raise AttributeError(f"Tokenizer must have a bos_token or cls_token: {tokenizer}")
        tokenizer.bos_token = tokenizer.cls_token
    if tokenizer.eos_token is None:
        if tokenizer.sep_token is None:
            raise AttributeError(f"Tokenizer must have a eos_token or sep_token: {tokenizer}")
        tokenizer.eos_token = tokenizer.sep_token

    if tokenizer.pad_token is None:
        # The log line is skipped for plain "gpt2"; the pad token is added
        # regardless.
        if name_or_path != "gpt2":
            rprint("Adding pad token to tokenizer")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    assert tokenizer.padding_side == 'right'
    assert tokenizer.truncation_side == 'right'

    return tokenizer
|
|
|
|
|
class SimpleDataLoader:
    """Minimal in-process batch iterator over an indexable, sized dataset.

    Items are read sequentially as ``dataset[0] .. dataset[len-1]``, grouped
    into lists of at most ``batch_size`` elements and passed through
    ``collate_fn``. The final batch may be shorter.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        batch_size: Positive number of items per batch.
        collate_fn: Callable applied to each list of items.
        **kwargs: Accepted so the constructor can be called like a
            ``DataLoader``; the extra options are ignored.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1 (such a value would
            otherwise loop forever yielding empty batches).
    """

    def __init__(self, dataset, batch_size=1, collate_fn=default_collate, **kwargs):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.idx = 0

    def __iter__(self):
        # Restart from the beginning on every new iteration so the loader can
        # be consumed more than once (e.g. one pass per epoch).
        self.idx = 0
        return self

    def __next__(self):
        total = len(self.dataset)
        if self.idx >= total:
            raise StopIteration

        stop = min(self.idx + self.batch_size, total)
        batch = [self.dataset[i] for i in range(self.idx, stop)]
        self.idx = stop
        return self.collate_fn(batch)

    def __len__(self):
        # Number of batches, counting a trailing partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size
|
|
|
def get_zero_shot_dataloader(config, tokenizer, device=None, **kwargs):
    """Build a dataloader over the ``test`` split of the zero-shot eval dataset.

    Supported values of ``config.data.zero_shot_eval_dataset`` are
    ``"nlphuji/flickr30k"`` and ``"facebook/winoground"``.

    Args:
        config: Provides ``data.zero_shot_eval_dataset``, ``data.num_proc``,
            ``data.cache_dir``, ``data.streaming``, ``seed``, ``mode`` and the
            ``loader.*`` options used below.
        tokenizer: Attached to the returned loader as ``.tokenizer``.
        device: Unused here; kept for signature compatibility.
        **kwargs: Forwarded to ``torch.utils.data.DataLoader``.

    Returns:
        ``(None, None)`` when no zero-shot dataset is configured; otherwise a
        single ``DataLoader``. NOTE(review): the two return shapes differ and
        are kept as-is because callers may depend on them.

    Raises:
        ValueError: if the configured dataset name is not supported.
    """
    if config.data.zero_shot_eval_dataset is None:
        rprint("No zero shot eval dataset provided")
        return None, None

    dataset_name = config.data.zero_shot_eval_dataset
    dataloader_seed = config.seed if config.mode == "eval" else 42

    supported = ("nlphuji/flickr30k", "facebook/winoground")
    if dataset_name not in supported:
        # Previously this fell through to an unbound ``dataset`` variable.
        raise ValueError(f"Unsupported zero shot eval dataset {dataset_name!r}; expected one of {supported}")

    # Both supported datasets are loaded identically and evaluated on "test".
    data = load_dataset(
        dataset_name,
        num_proc=config.data.num_proc,
        cache_dir=config.data.cache_dir,
        streaming=config.data.streaming,
    )
    dataset = data["test"]

    valid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.loader.eval_batch_size,
        num_workers=config.loader.num_eval_workers,
        pin_memory=config.loader.pin_memory,
        generator=torch.Generator().manual_seed(dataloader_seed),
        persistent_workers=False,
        **kwargs,
    )
    valid_loader.tokenizer = tokenizer
    return valid_loader
|
|
|
|
|
def get_dataloaders(config, tokenizer, skip_train=False, skip_valid=False, valid_seed=None, device=None, **kwargs):
    """Build the train and validation dataloaders described by ``config``.

    Args:
        config: Experiment config; ``config.data``, ``config.loader``,
            ``config.model.length``, ``config.mode`` and ``config.seed`` are
            read (plus ``config.backend`` when pinning a dataset to a device).
        tokenizer: Passed to the dataset builders and attached to each
            returned loader as ``.tokenizer``.
        skip_train: When true, no train dataset/loader is built.
        skip_valid: When true, no validation dataset/loader is built.
        valid_seed: Accepted for compatibility; not used by this function.
        device: Target device for ``TensorCollate`` and for
            ``config.data.pin_dataset_to_gpu``.
        **kwargs: Forwarded to ``get_dataset`` for both splits.

    Returns:
        ``(train_loader, valid_loader)``; either entry is ``None`` when the
        corresponding ``skip_*`` flag is set. If a dataset object's class is
        named ``WebLoader`` it is returned as the loader directly.
    """
    # ------------------------------------------------------------------ datasets
    if skip_train:
        train_set = None
    else:
        _mode = getattr(config.data, "force_train_mode", "train")
        if _mode != "train":
            rprint(f"Forcing train mode to {_mode}")
        train_set = get_dataset(
            config.data.train,
            tokenizer,
            mode=_mode,
            wrap=config.data.wrap,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            num_proc=config.data.num_proc,
            streaming=config.data.streaming,
            config=config,
            **kwargs,
        )
        if hasattr(train_set, '__len__'):
            rprint(f"Training set len: {len(train_set)}")

    # These datasets are evaluated on their "test" split.
    if config.data.valid in ["text8", "lm1b", "ag_news"]:
        validation_split = "test"
    else:
        validation_split = "validation"

    if skip_valid:
        valid_set = None
    else:
        valid_set = get_dataset(
            config.data.valid,
            tokenizer,
            wrap=config.data.wrap,
            mode=validation_split,
            cache_dir=config.data.cache_dir,
            block_size=config.model.length,
            streaming=False,
            num_proc=config.data.num_proc,
            config=config,
            **kwargs,
        )
        if hasattr(valid_set, '__len__'):
            rprint(f"Validation set len: {len(valid_set)}")

    # The configured seed is used for eval, on XLA, or when forced; otherwise 42.
    dataloader_seed = config.seed if (config.mode == "eval" or is_torch_xla_available() or getattr(config.data, "force_seed", False)) else 42
    gprint(f"Dataloader seed: {dataloader_seed}")

    # -------------------------------------------------------------- train loader
    if skip_train:
        train_loader = None
    else:
        train_kwargs = dict(drop_last=True)
        train_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        dl_cls = torch.utils.data.DataLoader
        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            train_kwargs.pop("drop_last", None)

        if getattr(config.loader, "disable_prefetch", False):
            train_kwargs["prefetch_factor"] = 1

        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                import webdataset
                dl_cls = webdataset.WebLoader
                train_kwargs["shuffle"] = False
                train_kwargs["prefetch_factor"] = 8
            elif getattr(config.data, "webdataset_indexed", False):
                import wids
                train_kwargs["sampler"] = wids.DistributedChunkedSampler(train_set, shuffle=True)
            elif isinstance(train_set, torch.utils.data.IterableDataset) is False:
                train_kwargs["shuffle"] = True

        if "tokens" in config.data.train and config.data.pin_dataset_to_gpu:
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moving dataloader to device {device} with: {cur_mb} GB of memory reserved")
            train_set = train_set.to(device=device)
            if config.backend == 'cuda':
                cur_mb = torch.cuda.memory_reserved() / 1e9
                rprint(f"Moved dataloader to device {device} with: {cur_mb} GB of memory reserved")

        if "tokens" in config.data.train:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                train_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
            else:
                train_kwargs["collate_fn"] = identity

            if getattr(config.data, "use_packing_collate", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                # Only a custom TensorCollate is chained inside the packer.
                token_collate = train_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
                train_kwargs["collate_fn"] = PackingCollate(config, train_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)

            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                train_kwargs['sampler'] = WeightedDatasetSampler(train_set, generator=generator)
                train_kwargs["shuffle"] = False
            else:
                train_kwargs["shuffle"] = True

        if getattr(config.data, "use_list_collate", False):
            # Module-level function instead of a lambda so the collate_fn can
            # be pickled for worker processes.
            train_kwargs["collate_fn"] = identity

        if getattr(config.data, "force_shuffle_train", False):
            rprint("Forcing shuffle on train dataloader")
            train_kwargs["shuffle"] = True

        if getattr(config.data, "force_disable_shuffle_train", False):
            rprint("Forcing disable shuffle on train dataloader")
            train_kwargs["shuffle"] = False

        if getattr(config.data, "force_distributed_sampler", False):
            import torch_xla.runtime as xr
            train_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                train_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True
            )
            # DataLoader rejects ``sampler`` combined with ``shuffle=True``;
            # the sampler itself shuffles.
            train_kwargs["shuffle"] = False

        if getattr(config.data, "use_identity_collate", False):
            train_kwargs["collate_fn"] = identity

        if train_set.__class__.__name__ == "WebLoader":
            train_loader = train_set
        else:
            rprint(f"Train dataloader kwargs: {train_kwargs}")
            train_loader = dl_cls(
                train_set,
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.batch_size,
                num_workers=config.loader.num_workers,
                pin_memory=config.loader.pin_memory,
                persistent_workers=config.loader.num_workers > 0 and getattr(config.loader, "persistent_workers", True),
                generator=train_dataloader_generator,
                **train_kwargs,
            )
        train_loader.tokenizer = tokenizer

    # --------------------------------------------------------- validation loader
    if skip_valid:
        valid_loader = None
    else:
        shuffle_valid = True
        valid_dataloader_generator = torch.Generator().manual_seed(dataloader_seed)
        valid_kwargs = dict(drop_last=True)

        dl_cls = torch.utils.data.DataLoader
        if getattr(config.data, "webdataset_iterable", False) or getattr(config.data, "webdataset_indexed", False):
            valid_kwargs.pop("drop_last", None)

        if getattr(config.data, "force_disable_shuffle", False) is False:
            if getattr(config.data, "webdataset_iterable", False):
                valid_kwargs["shuffle"] = False
                import webdataset
                dl_cls = webdataset.WebLoader
            elif getattr(config.data, "webdataset_indexed", False):
                import wids
                valid_kwargs["sampler"] = wids.DistributedChunkedSampler(valid_set, shuffle=True)
            elif isinstance(valid_set, torch.utils.data.IterableDataset) is False and shuffle_valid:
                valid_kwargs["shuffle"] = shuffle_valid

        if "tokens" in config.data.valid:
            if getattr(config.data, "use_custom_tensordict_collate", False):
                valid_kwargs["collate_fn"] = TensorCollate(device=device, enable_cuda_in_tensordict_collate=config.data.enable_cuda_in_tensordict_collate)
            else:
                valid_kwargs["collate_fn"] = identity

            if getattr(config.data, "use_packing_collate", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                token_collate = valid_kwargs["collate_fn"] if getattr(config.data, "use_custom_tensordict_collate", False) else None
                valid_kwargs["collate_fn"] = PackingCollate(config, valid_set, config.model.length, generator, tensor_collate=token_collate, tokenizer=tokenizer)

            if getattr(config.data, "use_weighted_tensordict_sampler", False):
                generator = torch.Generator().manual_seed(dataloader_seed)
                valid_kwargs['sampler'] = WeightedDatasetSampler(valid_set, generator=generator)

            if getattr(config.data, "shuffle_valid", False):
                torch.manual_seed(config.seed)

            valid_kwargs["shuffle"] = getattr(config.data, "shuffle_valid", False)

        if getattr(config.data, "force_distributed_sampler", False):
            import torch_xla.runtime as xr
            valid_kwargs["sampler"] = torch.utils.data.distributed.DistributedSampler(
                valid_set,
                num_replicas=xr.world_size(),
                rank=xr.global_ordinal(),
                shuffle=True
            )
            # See the train branch: ``sampler`` and ``shuffle=True`` are
            # mutually exclusive in DataLoader.
            valid_kwargs["shuffle"] = False

        if valid_set.__class__.__name__ == "WebLoader":
            valid_loader = valid_set
        else:
            rprint(f"Valid dataloader kwargs: {valid_kwargs}")
            valid_loader = dl_cls(
                valid_set,
                batch_size=None if getattr(config.data, "webdataset_iterable", False) else config.loader.eval_batch_size,
                num_workers=getattr(config.loader, "num_eval_workers", config.loader.num_workers),
                pin_memory=config.loader.pin_memory,
                generator=valid_dataloader_generator,
                persistent_workers=False,
                **valid_kwargs,
            )

        valid_loader.tokenizer = tokenizer

    return train_loader, valid_loader
|
|
|
|
|
|
|
|
|
|
|
class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):
    """Random sampler whose position inside an epoch can be checkpointed.

    ``state_dict`` captures the generator state from which the current
    epoch's permutation was drawn, plus how many indices have been handed
    out. After ``load_state_dict`` the next iteration regenerates the same
    permutation and skips the indices that were already consumed.
    """

    def __init__(self, *args, generator=None, **kwargs):
        # A concrete generator is required: its state is what gets saved and
        # restored. Without one, seed a fresh generator randomly.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        # RandomSampler takes no ``shuffle`` argument; drop it if a caller
        # passes DataLoader-style kwargs.
        kwargs.pop("shuffle", None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False
        # Generator state from which the current (or next) permutation is drawn.
        self.state = self.generator.get_state()

    def state_dict(self):
        """Return the resumable state: pre-permutation RNG state and position."""
        # Saving the generator's *current* state would be wrong mid-epoch: it
        # has already advanced past this epoch's ``randperm`` call, so a
        # resumed sampler would skip into a different permutation.
        return {"random_state": self.state, "counter": self.counter}

    def load_state_dict(self, state_dict):
        """Restore RNG state and position; the next ``__iter__`` resumes from them."""
        random_state = state_dict.get("random_state")
        self.generator.set_state(random_state)
        self.state = random_state
        self.counter = state_dict["counter"]
        self.restarting = True

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        # Remember where this epoch's permutation starts so it can be replayed.
        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            # Resuming: skip the indices that were already yielded.
            indices = indices[self.counter :]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        # Epoch finished: the next checkpoint should describe the next epoch.
        self.counter = 0
        self.state = self.generator.get_state()
|
|
|
|
|
class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):
    """Distributed sampler that can resume part-way through an epoch.

    The checkpoint holds the epoch number and how many indices this rank has
    already yielded. Because the permutation is derived from
    ``seed + epoch``, restoring those two values reproduces the same order.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        """Return ``{"epoch", "counter"}`` describing the current position."""
        return dict(epoch=self.epoch, counter=self.counter)

    def load_state_dict(self, state_dict):
        """Restore epoch and position; the next ``__iter__`` resumes from them."""
        self.epoch = state_dict["epoch"]
        self.counter = state_dict["counter"]
        self.restarting = True

    def __iter__(self):
        dataset_len = len(self.dataset)
        if self.shuffle:
            # Deterministic shuffle keyed on seed and epoch.
            rng = torch.Generator()
            rng.manual_seed(self.seed + self.epoch)
            order = torch.randperm(dataset_len, generator=rng).tolist()
        else:
            order = list(range(dataset_len))

        if self.drop_last:
            # Trim the tail so the list divides evenly across replicas.
            order = order[: self.total_size]
        else:
            # Pad by repeating indices from the start until evenly divisible.
            missing = self.total_size - len(order)
            if missing <= len(order):
                order += order[:missing]
            else:
                order += (order * math.ceil(missing / len(order)))[:missing]
        assert len(order) == self.total_size

        # Every ``num_replicas``-th index, starting at this rank.
        shard = order[self.rank : self.total_size : self.num_replicas]
        assert len(shard) == self.num_samples

        if self.restarting:
            # Skip what was already consumed before the checkpoint.
            shard = shard[self.counter :]
            self.restarting = False
        else:
            self.counter = 0

        for index in shard:
            self.counter += 1
            yield index

        self.counter = 0
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: build the facecaption dataset, fetch one element and
    # drop into the debugger to inspect it. Requires HF_DATASETS_CACHE to be
    # set in the environment.
    import os

    with breakpoint_on_error():
        from omegaconf import OmegaConf

        # Alternative configs kept for quick switching; only
        # ``facecaption_config`` is used below.
        cc12m_config = OmegaConf.create(
            dict(
                model=dict(image_model=True, unified_model=True),
                data=dict(
                    tokenizers_parallelism=False,
                    resolution=128,
                    train="pixparse/cc12m-wds",
                    val="pixparse/cc12m-wds",
                    streaming=False,
                    precache=True,
                    tokenizer_name_or_path="gpt2",
                    n_val_samples=None,
                    n_train_samples=None,
                    block_size=32,
                    data_dir="/path/to/cc12m",
                ),
            )
        )

        imagenet_config = OmegaConf.create(
            dict(
                model=dict(image_model=True),
                data=dict(
                    resolution=128,
                    train="ILSVRC/imagenet-1k",
                    val="ILSVRC/imagenet-1k",
                    streaming=False,
                    precache=True,
                    tokenizer_name_or_path="gpt2",
                ),
            )
        )

        facecaption_config = OmegaConf.create(
            dict(
                seed=12345,
                model=dict(image_model=True),
                data=dict(
                    resolution=256,
                    train="facecaption",
                    val="facecaption",
                    streaming=False,
                    precache=False,
                    tokenizer_name_or_path="gpt2",
                    cache_dir=os.environ["HF_DATASETS_CACHE"],
                    raw_data_dir="/grogu/user/mprabhud/data/diffusion/facecaption",
                    block_size=32,
                ),
                loader=dict(num_workers=0, batch_size=1, eval_batch_size=1),
                trainer=dict(devices=1, num_nodes=1, accumulate_grad_batches=1),
            )
        )

        tokenizer = get_tokenizer(facecaption_config)
        dataset = get_dataset(
            dataset_name=facecaption_config.data.train,
            mode="train",
            config=facecaption_config,
            tokenizer=tokenizer,
        )
        test = next(iter(dataset))
        # Intentional: stop here so ``test`` can be inspected interactively.
        breakpoint()
|
|
|
|
|
|
|
from typing import List, Dict |
|
import torch |
|
from tensordict import TensorDict |
|
def process_batch(batch: TensorDict):
    """Strip the ``write_flag`` / ``dataset_idx`` entries and refresh batch size.

    A list is processed element-wise and returned as a new list. Any other
    input is modified in place: the two keys are deleted when present,
    ``auto_batch_size_()`` is called, and the same object is returned.
    """
    if isinstance(batch, list):
        return [process_batch(item) for item in batch]

    for dropped_key in ("write_flag", "dataset_idx"):
        if dropped_key in batch:
            del batch[dropped_key]
    batch.auto_batch_size_()
    return batch
|
|
|
def ignore_slice(tensor, slice, padding_token_id):
    """Mark the positions selected by ``slice`` as ignored, in place.

    Within ``slice``: ``modality`` is set to -1, ``attention_mask`` to 0 and
    ``input_ids`` to ``padding_token_id``. If ``sample_ids`` exists it is set
    to -1 within ``slice``; if it does not, it is created filled with -1
    everywhere (not only within ``slice``), using the shape, dtype and device
    of ``input_ids``.

    Args:
        tensor: Mapping-like container of tensors (supports ``in`` and item
            get/set by key).
        slice: Index expression applied to each entry.
        padding_token_id: Fill value for ``input_ids``.
    """
    fills = (
        ("modality", -1),
        ("attention_mask", 0),
        ("input_ids", padding_token_id),
    )
    for key, fill_value in fills:
        tensor[key][slice] = fill_value

    if "sample_ids" not in tensor:
        reference = tensor["input_ids"]
        tensor["sample_ids"] = torch.full(reference.shape, fill_value=-1, dtype=reference.dtype, device=reference.device)
    else:
        tensor["sample_ids"][slice] = -1
|
|
|
class PackingCollate:
    """Collate function that packs samples into fixed-length rows.

    Each output row starts with the corresponding input sample. Unless
    ``config.data.disable_packing`` is set, remaining space in the row is
    filled with further samples drawn at random from ``dataset`` (indexed as
    ``dataset[(dataset_idx, element_idx)]`` with ``dataset.datasets``
    supplying the sizes). Positions not covered by any sample stay marked as
    ignored (see ``ignore_slice``).

    Args:
        config: Only ``config.data.disable_packing`` is read.
        dataset: Source for the extra samples.
        seq_length: Length of every output row.
        generator: ``torch.Generator`` used for the random draws.
        tensor_collate: Optional callable applied to the incoming batch first.
        tokenizer: Supplies ``pad_token_id``, ``eos_token_id`` and must encode
            ``"<image>"`` to exactly one id. NOTE(review): defaults to
            ``None`` but is dereferenced unconditionally.
    """

    def __init__(self, config, dataset, seq_length, generator, tensor_collate=None, tokenizer=None):
        self.dataset = dataset
        self.seq_length = seq_length
        self.tensor_collate = tensor_collate
        self.generator = generator
        self.tokenizer = tokenizer
        self.padding_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.disable_packing = getattr(config.data, "disable_packing", False)
        image_token_ids = tokenizer("<image>", add_special_tokens=False)['input_ids']
        assert len(image_token_ids) == 1
        self.image_token_id = image_token_ids[0]

    def __call__(self, batch: TensorDict):
        """Pack ``batch`` into a new container of shape ``(len(batch), seq_length)``."""
        if self.tensor_collate is not None:
            if isinstance(batch, list):
                batch = [self.tensor_collate(b) for b in batch]
            else:
                batch = self.tensor_collate(batch)

        B = len(batch)
        seq_length = self.seq_length

        batch = process_batch(batch)
        assert batch[0].batch_size is None or len(batch[0].batch_size) == 1

        # Start from an all-ignored (B, seq_length) container, then copy
        # samples into it row by row.
        new_batch = batch[0].new_zeros((B, seq_length))
        ignore_slice(new_batch, slice(None, None), self.padding_token_id)

        for i in range(B):
            total_length = 0
            sample_idx = 0
            sample_queue = [batch[i]]

            while total_length < seq_length:
                if self.disable_packing and sample_idx > 0:
                    break

                if sample_queue:
                    sample = sample_queue.pop(0)
                else:
                    # Queue exhausted: draw a random extra sample.
                    dataset_idx = torch.randint(len(self.dataset.datasets), (1,), generator=self.generator).item()
                    element_idx = torch.randint(len(self.dataset.datasets[dataset_idx]), (1,), generator=self.generator).item()
                    sample = self.dataset[(dataset_idx, element_idx)]
                    sample = process_batch(sample)

                available_length = seq_length - total_length
                if available_length < sample.shape[0] // 4:
                    # Less than a quarter of this sample would fit: stop if
                    # the row already has content, otherwise try another.
                    if total_length > 0:
                        break
                    continue

                if "sample_ids" not in sample:
                    # Positions before the first padding token become 0,
                    # everything from the first padding token on becomes -1.
                    sequence_starts = (sample['input_ids'] == self.padding_token_id).long()
                    sample["sample_ids"] = torch.cumsum(sequence_starts, dim=0) - 1
                    processed_ids = torch.where(sample["sample_ids"] < 0, torch.zeros_like(sample["sample_ids"]), -1)
                    sample["sample_ids"] = processed_ids

                if not ((sample["sample_ids"] == 0) | (sample["sample_ids"] == -1)).all():
                    assert (sample["modality"] == 0).all()

                # Usable length = position of the first -1 sample id, or the
                # whole sample when there is none.
                ignored_positions = (sample["sample_ids"] == -1).nonzero(as_tuple=True)[0]
                if ignored_positions.numel() > 0:
                    first_neg_one = ignored_positions[0].item()
                else:
                    assert sample["attention_mask"].all()
                    first_neg_one = len(sample["attention_mask"])

                new_length = min(first_neg_one, available_length)
                valid_slice = slice(None, new_length)

                sample["sample_ids"][valid_slice] = sample_idx
                new_batch[i, total_length:total_length+new_length] = sample[valid_slice]

                total_length += new_length
                sample_idx += 1

            if (new_batch["sample_ids"] == -1).all():
                gprint(f"WARNING!!!! All sample ids are -1 in packing collate before ignore")

            if new_batch["modality"][i, -1] == 1:
                # The row ends with modality == 1: locate where that trailing
                # run starts and mark it as ignored.
                modality_slice = new_batch["modality"][i]
                is_image = modality_slice == 1

                change_points = torch.where(is_image[:-1] != is_image[1:])[0] + 1

                if change_points.numel() > 0 and is_image[-1]:
                    start_pos = change_points[-1].item()
                    assert (new_batch["modality"][i, start_pos:] == 1).all()
                    try:
                        # Also drop the image marker token right before the run.
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] == self.image_token_id:
                            start_pos -= 1

                        # Make sure the kept prefix ends with EOS.
                        if start_pos > 0 and new_batch["input_ids"][i, start_pos - 1] != self.eos_token_id:
                            new_batch["input_ids"][i, start_pos] = self.eos_token_id
                            new_batch["attention_mask"][i, start_pos] = 1
                            new_batch["modality"][i, start_pos] = 0
                            start_pos += 1

                    except IndexError:
                        print(f"WARNING!!!! ERROR IN PACKING COLLATE")

                    ignore_slice(new_batch[i], slice(start_pos, None), self.padding_token_id)

            if (new_batch["sample_ids"] == -1).all():
                gprint(f"WARNING!!!! All sample ids are -1 in packing collate after ignore")

        return new_batch
|
|
|
|