|
from __future__ import annotations |
|
from tensordict import tensorclass |
|
import torch |
|
from torch import nn |
|
from typing import Optional |
|
from unidisc.utils.tensor_utils import get_contiguous_blocks, get_interleaved_indices |
|
from tensordict import TensorDict |
|
|
|
@tensorclass
class InterleavedBatch:
    """Batch of token sequences annotated with per-position modality and sample id.

    ``sample_ids`` is indexed with two indices (row, position) below, so the
    tensors are at least two-dimensional -- presumably ``(batch, seq)``;
    confirm against the data loader.
    """

    input_ids: torch.Tensor
    modality: torch.Tensor
    sample_ids: torch.Tensor
    attention_mask: Optional[torch.Tensor] = None

    def to_ragged_batch(self):
        """Cut the batch into one slice per contiguous run of ``sample_ids``.

        Runs whose first sample id equals ``-1`` are dropped (presumably
        padding -- confirm). Note that this calls ``auto_batch_size_()`` on
        ``self``, i.e. it updates the batch size of this instance in place.

        Returns:
            A lazy stack (``TensorDict.lazy_stack``) of the kept slices, in the
            order reported by ``get_contiguous_blocks``.
        """
        # NOTE(review): get_contiguous_blocks is assumed to return, per run,
        # its row index, start position and exclusive end position -- verify.
        rows, starts, stops = get_contiguous_blocks(self.sample_ids)
        leading_ids = self.sample_ids[rows, starts]
        # Let tensordict infer the batch size before the two-index slicing.
        self.auto_batch_size_()

        segments = [
            self[rows[k], starts[k]:stops[k]]
            for k in range(rows.shape[0])
            if leading_ids[k] != -1
        ]
        return TensorDict.lazy_stack(segments, dim=0)

    def to_elements(self):
        """Convert every ragged slice into an ``InterleavedElement``.

        Returns:
            A lazy stack of ``InterleavedElement.from_raw(...)`` results, one
            per entry of ``to_ragged_batch()``.
        """
        ragged = self.to_ragged_batch()
        elements = [
            InterleavedElement.from_raw(ragged[k]) for k in range(ragged.shape[0])
        ]
        return TensorDict.lazy_stack(elements, dim=0)

    @classmethod
    def custom_from_dict(cls, data: TensorDict):
        """Build an instance from the entries of ``data`` that match class fields.

        Keys of ``data`` that are not fields of this class are ignored; fields
        missing from ``data`` are simply not passed to the constructor.
        """
        present = {
            spec.name: data[spec.name] for spec in cls.fields() if spec.name in data
        }
        return cls(**present)
|
|
|
|
|
@tensorclass
class InterleavedElement:
    """One sample split into its text blocks and image blocks.

    Instances are normally built with ``from_raw``; ``to_list`` merges the
    blocks back into a single interleaved sequence.
    """

    # Text blocks, in order of appearance (filled by ``from_raw``).
    txt_input_ids: Optional[list[torch.Tensor]] = None
    # Image blocks, in order of appearance (filled by ``from_raw``).
    img_input_ids: Optional[list[torch.Tensor]] = None
    # Not populated by any code in this class.
    txt: Optional[torch.Tensor] = None
    # Not populated by any code in this class.
    img: Optional[torch.Tensor] = None
    # For each image block, the index of the text block it follows.
    img_pos_ids: Optional[torch.Tensor] = None
    # Raw outputs of ``get_contiguous_blocks`` as stored by ``from_raw``.
    batch_indices: Optional[torch.Tensor] = None
    start_positions: Optional[torch.Tensor] = None
    end_positions: Optional[torch.Tensor] = None
    # The slice this element was built from.
    raw_data: Optional[InterleavedBatch] = None

    @classmethod
    def from_raw(cls, interleaved_batch: InterleavedBatch):
        """Split a single-sample slice into text and image blocks.

        Blocks are the contiguous runs of ``interleaved_batch.modality``; a
        run whose first modality value is ``1`` is treated as an image block,
        anything else as a text block.

        Args:
            interleaved_batch: Slice holding one sample. ``modality`` and
                ``input_ids`` are indexed along their first dimension here, so
                they are presumably 1-D -- confirm against the caller.

        Returns:
            A new ``InterleavedElement`` with the block lists, image positions,
            block boundaries and the original slice attached.

        Raises:
            AssertionError: If an image block appears before any text block.
        """
        # A leading dimension is added before the helper call; presumably
        # get_contiguous_blocks expects (batch, seq) input -- verify.
        batch_indices, start_positions, end_positions = get_contiguous_blocks(
            interleaved_batch.modality[None]
        )
        block_modality = interleaved_batch.modality[start_positions]

        img_input_ids = []
        txt_input_ids = []
        img_pos_ids = []
        for i in range(batch_indices.shape[0]):
            if block_modality[i] == 1:
                # NOTE: kept as an assert so the exception type is unchanged;
                # be aware that it is stripped when Python runs with -O.
                assert len(txt_input_ids) > 0, (
                    "image block encountered before any text block"
                )
                img_input_ids.append(
                    interleaved_batch.input_ids[start_positions[i]:end_positions[i]]
                )
                # Anchor the image to the most recent text block.
                img_pos_ids.append(len(txt_input_ids) - 1)
            else:
                txt_input_ids.append(
                    interleaved_batch.input_ids[start_positions[i]:end_positions[i]]
                )

        return cls(
            img_input_ids=img_input_ids,
            txt_input_ids=txt_input_ids,
            # Explicit integer dtype: torch.tensor([]) would otherwise produce
            # a float32 tensor for elements that contain no image block.
            img_pos_ids=torch.tensor(img_pos_ids, dtype=torch.long),
            batch_indices=batch_indices,
            start_positions=start_positions,
            end_positions=end_positions,
            raw_data=interleaved_batch,
        )

    def to_list(self):
        """Merge text and image blocks back into one interleaved sequence.

        Each text block is emitted, followed by every image block whose entry
        in ``img_pos_ids`` equals that text block's index.

        Returns:
            A tuple ``(data, modalities)``: ``data`` is the list of blocks in
            interleaved order and ``modalities`` holds ``0`` for each text
            block and ``1`` for each image block.

        Raises:
            ValueError: If a pending image refers to a text index that has
                already been passed; the merge could never finish in that case.
        """
        # Read the fields once instead of on every loop iteration.
        txt_blocks = self.txt_input_ids
        img_blocks = self.img_input_ids
        img_pos_ids = self.img_pos_ids
        num_txt = len(txt_blocks)
        num_img = len(img_blocks)

        txt_idx = 0
        img_idx = 0
        has_added_txt = False
        data = []
        modalities = []
        while txt_idx < num_txt or img_idx < num_img:
            if not has_added_txt and txt_idx < num_txt:
                data.append(txt_blocks[txt_idx])
                modalities.append(0)
                has_added_txt = True
            elif img_idx < num_img and img_pos_ids[img_idx] == txt_idx:
                data.append(img_blocks[img_idx])
                modalities.append(1)
                img_idx += 1
            else:
                # txt_idx only ever grows, so an image anchored behind it can
                # never be emitted; fail loudly instead of looping forever.
                if img_idx < num_img and img_pos_ids[img_idx] < txt_idx:
                    raise ValueError(
                        "img_pos_ids must be non-decreasing and non-negative; "
                        f"image {img_idx} points before text index {txt_idx}"
                    )
                has_added_txt = False
                txt_idx += 1

        return data, modalities