from __future__ import annotations

import dataclasses
from typing import Optional

import torch
from torch import nn
from tensordict import TensorDict, tensorclass

from unidisc.utils.tensor_utils import get_contiguous_blocks, get_interleaved_indices


@tensorclass
class InterleavedBatch:
    """A packed batch of interleaved text/image sequences.

    All tensors share the shape (batch, seq_len). `modality` is 0 for text
    tokens and 1 for image tokens; `sample_ids` labels which sample each
    token belongs to, with -1 marking padding positions.
    """

    input_ids: torch.Tensor
    modality: torch.Tensor
    sample_ids: torch.Tensor
    attention_mask: Optional[torch.Tensor] = None

    def to_ragged_batch(self):
        """Split the packed batch into one entry per sample, dropping padding.

        Each contiguous run of identical sample_ids becomes one entry; runs
        with sample_id == -1 (padding) are skipped. Entries have different
        lengths, so they are combined with a lazy stack.
        """
        batch_indices, start_positions, end_positions = get_contiguous_blocks(self.sample_ids)
        first_sample_ids = self.sample_ids[batch_indices, start_positions]
        self.auto_batch_size_()
        data = []
        for i in range(batch_indices.shape[0]):
            if first_sample_ids[i] == -1:
                continue  # padding block
            data.append(self[batch_indices[i], start_positions[i]:end_positions[i]])
        return TensorDict.lazy_stack(data, dim=0)

    def to_elements(self):
        """Convert each per-sample slice into a structured InterleavedElement."""
        data = self.to_ragged_batch()
        new_data = [InterleavedElement.from_raw(data[i]) for i in range(data.shape[0])]
        return TensorDict.lazy_stack(new_data, dim=0)

    @classmethod
    def custom_from_dict(cls, data: TensorDict):
        """Build an InterleavedBatch from a TensorDict, keeping only known fields."""
        new_dict = {}
        # tensorclass wraps a dataclass, so the stdlib fields() introspection applies.
        for field in dataclasses.fields(cls):
            if field.name in data:
                new_dict[field.name] = data[field.name]
        return cls(**new_dict)


@tensorclass
class InterleavedElement:
    """A single interleaved sample, decomposed into text and image segments.

    `txt_input_ids[i]` is the i-th text segment; `img_input_ids[j]` is the
    j-th image segment, and `img_pos_ids[j]` records the index of the text
    segment that immediately precedes it.
    """

    txt_input_ids: Optional[list[torch.Tensor]] = None
    img_input_ids: Optional[list[torch.Tensor]] = None
    txt: Optional[torch.Tensor] = None
    img: Optional[torch.Tensor] = None
    img_pos_ids: Optional[torch.Tensor] = None
    batch_indices: Optional[torch.Tensor] = None
    start_positions: Optional[torch.Tensor] = None
    end_positions: Optional[torch.Tensor] = None
    raw_data: Optional[InterleavedBatch] = None

    @classmethod
    def from_raw(cls, interleaved_batch: InterleavedBatch):
        """Split one sample into alternating text/image blocks by modality."""
        # modality is 1D for a single sample, so add a leading batch dim.
        batch_indices, start_positions, end_positions = get_contiguous_blocks(interleaved_batch.modality[None])
        block_modality = interleaved_batch.modality[start_positions]
        img_input_ids = []
        txt_input_ids = []
        img_pos_ids = []
        for i in range(batch_indices.shape[0]):
            if block_modality[i] == 1:  # image block
                # Sequences are assumed to start with text; each image is
                # anchored to the text segment that precedes it.
                assert len(txt_input_ids) > 0
                img_input_ids.append(interleaved_batch.input_ids[start_positions[i]:end_positions[i]])
                img_pos_ids.append(len(txt_input_ids) - 1)
            else:  # text block
                txt_input_ids.append(interleaved_batch.input_ids[start_positions[i]:end_positions[i]])
        return cls(
            img_input_ids=img_input_ids,
            txt_input_ids=txt_input_ids,
            img_pos_ids=torch.tensor(img_pos_ids),
            batch_indices=batch_indices,
            start_positions=start_positions,
            end_positions=end_positions,
            raw_data=interleaved_batch,
        )

    def to_list(self):
        """Flatten back into an ordered list of segments plus modality flags.

        Returns (data, modalities), where data[k] is a tensor of token ids
        and modalities[k] is 0 for text and 1 for image, in original order.
        """
        txt_idx = 0
        img_idx = 0
        has_added_txt = False
        data = []
        modalities = []
        while txt_idx < len(self.txt_input_ids) or img_idx < len(self.img_input_ids):
            if not has_added_txt and txt_idx < len(self.txt_input_ids):
                # Emit the current text segment once...
                data.append(self.txt_input_ids[txt_idx])
                modalities.append(0)
                has_added_txt = True
            elif img_idx < len(self.img_input_ids) and self.img_pos_ids[img_idx] == txt_idx:
                # ...then every image anchored to it, in order.
                data.append(self.img_input_ids[img_idx])
                modalities.append(1)
                img_idx += 1
            else:
                # No more images for this text segment; advance to the next one.
                has_added_txt = False
                txt_idx += 1
        return data, modalities
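

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original module: a minimal
    # sketch assuming get_contiguous_blocks(x) returns (batch_indices,
    # block_starts, block_ends) for maximal runs of equal values along the
    # last dim, and that the tensorclass accepts an explicit batch_size.
    input_ids = torch.arange(10)[None]  # (1, 10) toy token ids
    modality = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 0, 0, 0]])  # text, image, text
    sample_ids = torch.zeros(1, 10, dtype=torch.long)  # one sample, no padding

    batch = InterleavedBatch(
        input_ids=input_ids,
        modality=modality,
        sample_ids=sample_ids,
        batch_size=[1, 10],
    )
    element = batch.to_elements()[0]
    data, modalities = element.to_list()
    print(modalities)  # expected: [0, 1, 0]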