# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
Implement base data transfer protocol between any two functions, modules. | |
We can subclass Protocol to define more detailed batch info with specific keys | |
""" | |
import copy
import io
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import ray
import torch
from numpy.typing import NDArray
from tensordict import TensorDict
from torch.distributed import ProcessGroup
from torch.utils.data import DataLoader

from .utils.py_functional import union_two_dict


try:
    import tensordict

    tensordict.set_lazy_legacy(False).set()
except Exception:
    pass


__all__ = ["DataProto", "union_tensor_dict"]


def pad_dataproto_to_divisor(data: "DataProto", size_divisor: int) -> Tuple["DataProto", int]:
    """Pad a DataProto so that its size is divisible by size_divisor.

    Args:
        data (DataProto): the unpadded DataProto
        size_divisor (int): size divisor

    Returns:
        data (DataProto): the padded DataProto
        pad_size (int): the number of padded samples
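
    Example (illustrative sketch; the tensor key ``obs`` is arbitrary)::

        data = DataProto.from_dict(tensors={"obs": torch.randn(7, 4)})
        padded, pad_size = pad_dataproto_to_divisor(data, size_divisor=4)
        # len(padded) == 8 and pad_size == 1; the padding is removed later
        # with unpad_dataproto(padded, pad_size).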
""" | |
assert isinstance(data, DataProto), "data must be a DataProto" | |
if len(data) % size_divisor != 0: | |
pad_size = size_divisor - len(data) % size_divisor | |
padding_protos = [] | |
remaining_pad = pad_size | |
while remaining_pad > 0: | |
take_size = min(remaining_pad, len(data)) | |
padding_protos.append(data[:take_size]) | |
remaining_pad -= take_size | |
data_padded = DataProto.concat([data] + padding_protos) | |
else: | |
pad_size = 0 | |
data_padded = data | |
return data_padded, pad_size | |
def unpad_dataproto(data: "DataProto", pad_size: int) -> "DataProto":
    if pad_size != 0:
        data = data[:-pad_size]

    return data


def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Union two tensordicts."""
    if tensor_dict1.batch_size != tensor_dict2.batch_size:
        raise ValueError(
            f"Two tensor dicts must have identical batch sizes. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}."
        )

    for key in tensor_dict2.keys():
        if key in tensor_dict1 and not torch.equal(tensor_dict1[key], tensor_dict2[key]):
            raise ValueError(f"Key {key} already exists with a different value.")

        tensor_dict1[key] = tensor_dict2[key]

    return tensor_dict1


def union_numpy_dict(tensor_dict1: Dict[str, NDArray], tensor_dict2: Dict[str, NDArray]) -> Dict[str, NDArray]:
    for key in tensor_dict2.keys():
        if key in tensor_dict1:
            assert isinstance(tensor_dict2[key], np.ndarray)
            assert isinstance(tensor_dict1[key], np.ndarray)
            if not np.all(tensor_dict1[key] == tensor_dict2[key]):
                raise ValueError(f"Key {key} already exists with a different value.")

        tensor_dict1[key] = tensor_dict2[key]

    return tensor_dict1


def batch_collate(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    if len(features) == 0:
        return {}

    batch_features = defaultdict(list)
    for feature in features:
        for key, value in feature.items():
            batch_features[key].append(value)

    return batch_features


def fold_batch_dim(data: "DataProto", new_batch_size: int) -> "DataProto":
    """
    Fold a batch dim from [bsz, xxx] into [new_bsz, bsz // new_bsz, xxx].
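
    Example (illustrative sketch; the tensor key ``obs`` is arbitrary)::

        data = DataProto.from_dict(tensors={"obs": torch.randn(8, 4)})
        folded = fold_batch_dim(data, new_batch_size=2)
        # folded.batch["obs"].shape == (2, 4, 4), with batch_size (2,)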
""" | |
batch_size = data.batch.batch_size[0] | |
assert batch_size % new_batch_size == 0 | |
tensor: TensorDict = data.batch | |
non_tensor = data.non_tensor_batch | |
tensor = tensor.view(new_batch_size, -1) | |
tensor.auto_batch_size_(batch_dims=1) | |
for key, value in non_tensor.items(): | |
non_tensor[key] = np.reshape(value, newshape=(new_batch_size, -1, *value.shape[1:])) | |
return DataProto(batch=tensor, non_tensor_batch=non_tensor, meta_info=data.meta_info) | |
def collate_fn(data_items: list["DataProtoItem"]):
    batch = []
    non_tensor_batch = []
    for data in data_items:
        batch.append(data.batch)
        non_tensor_batch.append(data.non_tensor_batch)

    batch = torch.stack(batch).contiguous()
    non_tensor_batch = batch_collate(non_tensor_batch)
    non_tensor_batch = {key: np.array(value, dtype=object) for key, value in non_tensor_batch.items()}
    return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)


@dataclass
class DataProtoItem:
    batch: Optional[TensorDict] = None
    non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict)
    meta_info: Dict[str, Any] = field(default_factory=dict)


@dataclass
class DataProto:
    """
    A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
    It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict (https://pytorch.org/tensordict/),
    which lets you manipulate a dictionary of tensors like a single tensor. Ideally, tensors with the
    same batch size should be put inside the batch.
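
    Example (illustrative sketch; key names and values are arbitrary, and numpy/torch
    are assumed to be imported as in this module)::

        data = DataProto.from_dict(
            tensors={"input_ids": torch.randint(0, 100, (4, 16))},
            non_tensors={"raw_prompt": np.array(["a", "b", "c", "d"], dtype=object)},
            meta_info={"temperature": 1.0},
        )
        assert len(data) == 4
        item = data[0]     # integer indexing yields a DataProtoItem
        shard = data[0:2]  # slicing yields a DataProto of batch size 2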
""" | |
batch: Optional[TensorDict] = None | |
non_tensor_batch: Dict[str, NDArray] = field(default_factory=dict) | |
meta_info: Dict[str, Any] = field(default_factory=dict) | |
def __post_init__(self): | |
self.check_consistency() # perform necessary checking | |
def __len__(self) -> int: | |
if self.batch is not None: | |
return self.batch.batch_size[0] | |
elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: | |
pivot_key = list(self.non_tensor_batch.keys())[0] | |
return self.non_tensor_batch[pivot_key].shape[0] | |
else: | |
return 0 | |
    def __getitem__(self, item: Union[int, slice]) -> Union["DataProto", "DataProtoItem"]:
        tensor_data = self.batch[item]
        non_tensor_data = {key: value[item] for key, value in self.non_tensor_batch.items()}
        return_type = DataProto if isinstance(item, slice) else DataProtoItem
        return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

    def __getstate__(self) -> Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]:
        buffer = io.BytesIO()
        if self.batch is not None:
            self.batch: TensorDict = self.batch.contiguous()
            self.batch: TensorDict = self.batch.consolidate()

        torch.save(self.batch, buffer)
        buffer_bytes = buffer.getvalue()
        return buffer_bytes, self.non_tensor_batch, self.meta_info

    def __setstate__(self, data: Tuple[bytes, Dict[str, NDArray], Dict[str, Any]]) -> None:
        batch_deserialized_bytes, non_tensor_batch, meta_info = data
        batch_deserialized = io.BytesIO(batch_deserialized_bytes)
        batch = torch.load(batch_deserialized, weights_only=False, map_location="cpu")
        self.batch = batch
        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info

    def save_to_disk(self, filepath: str) -> None:
        with open(filepath, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def load_from_disk(filepath: str) -> "DataProto":
        with open(filepath, "rb") as f:
            data = pickle.load(f)
            return data

    def print_size(self, prefix: str = "") -> None:
        size_of_tensordict = 0
        if self.batch is not None:
            for tensor in self.batch.values():
                if isinstance(tensor, torch.Tensor):
                    size_of_tensordict += tensor.element_size() * tensor.numel()

        size_of_numpy_array = 0
        for value in self.non_tensor_batch.values():
            size_of_numpy_array += value.nbytes

        size_of_numpy_array /= 1024**3
        size_of_tensordict /= 1024**3

        message = f"Size of tensordict: {size_of_tensordict} GB, size of non_tensor_batch: {size_of_numpy_array} GB."
        print(f"{prefix}{message}")

    def check_consistency(self):
        """Check the consistency of the DataProto, mainly for batch and non_tensor_batch.

        We expose this as a public function so that users can call it directly.
        """
        if self.batch is not None:
            assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1"

        if self.batch is not None and len(self.non_tensor_batch) != 0:
            # TODO: we can actually lift this restriction if needed
            assert len(self.batch.batch_size) == 1, "only support num_batch_dims=1 when non_tensor_batch is not empty."

            batch_size = self.batch.batch_size[0]
            for key, value in self.non_tensor_batch.items():
                assert len(value) == batch_size, f"key {key} length {len(value)} is not equal to bsz {batch_size}."

    @classmethod
    def from_single_dict(
        cls,
        data: Dict[str, Union[torch.Tensor, NDArray]],
        meta_info: Optional[Dict[str, Any]] = None,
    ) -> "DataProto":
        tensors, non_tensors = {}, {}
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                tensors[key] = value
            elif isinstance(value, np.ndarray):
                non_tensors[key] = value
            else:
                raise ValueError(f"Unsupported type in data {type(value)}")

        return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info)

    @classmethod
    def from_dict(
        cls,
        tensors: Dict[str, torch.Tensor],
        non_tensors: Optional[Dict[str, NDArray]] = None,
        meta_info: Optional[Dict[str, Any]] = None,
        num_batch_dims: int = 1,
    ) -> "DataProto":
        """Create a DataProto from a dict of tensors. This assumes that
        1. All the tensors in ``tensors`` have the same dim0.
        2. Only dim0 is the batch dim.
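
        Example (illustrative sketch; key names are arbitrary)::

            data = DataProto.from_dict(
                tensors={"obs": torch.zeros(4, 8), "act": torch.zeros(4, 2)},
                meta_info={"source": "demo"},
            )
            # data.batch.batch_size == torch.Size([4]) and len(data) == 4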
""" | |
assert len(tensors) > 0, "tensors must not be empty" | |
assert num_batch_dims > 0, "num_batch_dims must be greater than zero" | |
if non_tensors is not None: | |
assert num_batch_dims == 1, "only support num_batch_dims=1 when non_tensors is not None." | |
meta_info = meta_info or {} | |
non_tensors = non_tensors or {} | |
assert isinstance(non_tensors, dict), "non_tensors should be a dictionary." | |
# get and check batch size | |
batch_size = None | |
pivot_key = None | |
for key, tensor in tensors.items(): | |
if batch_size is None: | |
batch_size = tensor.shape[:num_batch_dims] | |
pivot_key = key | |
else: | |
current_batch = tensor.shape[:num_batch_dims] | |
assert batch_size == current_batch, ( | |
f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. " | |
f"Got {pivot_key} has {batch_size}, {key} has {current_batch}" | |
) | |
tensor_dict = TensorDict(source=tensors, batch_size=batch_size) | |
return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info) | |
    def to(self, device: torch.device) -> "DataProto":
        """Move the batch to device.

        Args:
            device (torch.device, str): torch device

        Returns:
            DataProto: the current DataProto
        """
        if self.batch is not None:
            self.batch = self.batch.to(device)

        return self

    def select(
        self,
        batch_keys: Optional[List[str]] = None,
        non_tensor_batch_keys: Optional[List[str]] = None,
        meta_info_keys: Optional[List[str]] = None,
        deepcopy: bool = False,
    ) -> "DataProto":
        """Select a subset of the DataProto via batch_keys and meta_info_keys.

        Args:
            batch_keys (list, optional): a list of strings indicating the keys in batch to select
            non_tensor_batch_keys (list, optional): a list of keys indicating the non-tensor batch entries to select
            meta_info_keys (list, optional): a list of keys indicating the meta info to select
            deepcopy (bool): whether to deep-copy the selected non_tensor_batch and meta_info

        Returns:
            DataProto: the DataProto with the selected batch_keys and meta_info_keys
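
        Example (illustrative sketch; key names are arbitrary)::

            subset = data.select(batch_keys=["input_ids"], meta_info_keys=["temperature"])
            # the tensors in subset.batch are the same underlying tensors as in data;
            # deepcopy=True only copies the selected non_tensor_batch and meta_info.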
""" | |
# TODO (zhangchi.usc1992) whether to copy | |
if batch_keys is not None: | |
batch_keys = tuple(batch_keys) | |
sub_batch = self.batch.select(*batch_keys) | |
else: | |
sub_batch = self.batch | |
if non_tensor_batch_keys is not None: | |
non_tensor_batch = {k: v for k, v in self.non_tensor_batch.items() if k in non_tensor_batch_keys} | |
else: | |
non_tensor_batch = self.non_tensor_batch | |
if deepcopy: | |
non_tensor_batch = copy.deepcopy(non_tensor_batch) | |
if meta_info_keys is not None: | |
sub_meta_info = {k: v for k, v in self.meta_info.items() if k in meta_info_keys} | |
else: | |
sub_meta_info = self.meta_info | |
if deepcopy: | |
sub_meta_info = copy.deepcopy(sub_meta_info) | |
return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info) | |
    def pop(
        self,
        batch_keys: Optional[List[str]] = None,
        non_tensor_batch_keys: Optional[List[str]] = None,
        meta_info_keys: Optional[List[str]] = None,
    ) -> "DataProto":
        """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`.

        Args:
            batch_keys (list, optional): a list of strings indicating the keys in batch to pop
            non_tensor_batch_keys (list, optional): a list of keys indicating the non-tensor batch entries to pop
            meta_info_keys (list, optional): a list of keys indicating the meta info to pop

        Returns:
            DataProto: the DataProto with the popped batch_keys and meta_info_keys
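
        Example (illustrative sketch; the key name is arbitrary)::

            popped = data.pop(batch_keys=["old_log_probs"])
            # "old_log_probs" is removed from data.batch and returned inside popped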
""" | |
assert batch_keys is not None | |
non_tensor_batch_keys = non_tensor_batch_keys or [] | |
meta_info_keys = meta_info_keys or [] | |
tensors = {} | |
for key in batch_keys: | |
tensors[key] = self.batch.pop(key) | |
non_tensors = {} | |
for key in non_tensor_batch_keys: | |
non_tensors[key] = self.non_tensor_batch.pop(key) | |
meta_info = {} | |
for key in meta_info_keys: | |
meta_info[key] = self.meta_info.pop(key) | |
return DataProto.from_dict(tensors=tensors, non_tensors=non_tensors, meta_info=meta_info) | |
    def rename(
        self, old_keys: Optional[Union[str, List[str]]] = None, new_keys: Optional[Union[str, List[str]]] = None
    ) -> "DataProto":
        """
        Note that this function only renames keys in the batch.
        """

        def validate_input(keys):
            if keys is not None:
                if isinstance(keys, str):
                    keys = [keys]
                elif isinstance(keys, list):
                    pass
                else:
                    raise TypeError(f"keys must be a list or a string, but got {type(keys)}")

            return keys

        old_keys = validate_input(old_keys)
        new_keys = validate_input(new_keys)
        if len(new_keys) != len(old_keys):
            raise ValueError(
                f"new_keys and old_keys must have the same length, but got {len(new_keys)} and {len(old_keys)}"
            )

        self.batch.rename_key_(tuple(old_keys), tuple(new_keys))
        return self

    def union(self, other: "DataProto") -> "DataProto":
        """Union with another DataProto. Union batch and meta_info separately.

        Throw an error if
        - there are conflicting keys in batch and they are not equal
        - the batch sizes of the two DataProtos are not the same
        - there are conflicting keys in meta_info and they are not the same

        Args:
            other (DataProto): another DataProto to union

        Returns:
            DataProto: the DataProto after union
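
        Example (illustrative sketch; key names are arbitrary)::

            rollout = DataProto.from_dict(tensors={"responses": torch.zeros(4, 8)})
            rewards = DataProto.from_dict(tensors={"token_level_rewards": torch.zeros(4, 8)})
            merged = rollout.union(rewards)  # modifies rollout in place and returns it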
""" | |
self.batch = union_tensor_dict(self.batch, other.batch) | |
self.non_tensor_batch = union_numpy_dict(self.non_tensor_batch, other.non_tensor_batch) | |
self.meta_info = union_two_dict(self.meta_info, other.meta_info) | |
return self | |
    def make_iterator(
        self,
        mini_batch_size: int,
        epochs: int,
        seed: Optional[int] = None,
        dataloader_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Make an iterator from the DataProto. This is built upon the fact that a TensorDict can be used as a normal
        PyTorch dataset. See https://pytorch.org/tensordict/tutorials/data_fashion for more details.

        Args:
            mini_batch_size (int): mini-batch size when iterating the dataset. We require that
                ``batch.batch_size[0] % mini_batch_size == 0``.
            epochs (int): number of epochs when iterating the dataset.
            seed (int, optional): seed for the dataloader's random generator.
            dataloader_kwargs: internally, this method builds a DataLoader over the batch;
                ``dataloader_kwargs`` is passed through to that DataLoader.

        Returns:
            Iterator: an iterator that yields one mini-batch at a time. The total number of iteration steps is
            ``self.batch.batch_size[0] * epochs // mini_batch_size``.
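
        Example (illustrative sketch)::

            for mini_batch in data.make_iterator(mini_batch_size=2, epochs=1, seed=42):
                ...  # each mini_batch is a DataProto of size 2 carrying the same meta_info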
""" | |
assert self.batch.batch_size[0] % mini_batch_size == 0, f"{self.batch.batch_size[0]} % {mini_batch_size} != 0" | |
# we can directly create a dataloader from TensorDict | |
if dataloader_kwargs is None: | |
dataloader_kwargs = {} | |
if seed is not None: | |
generator = torch.Generator() | |
generator.manual_seed(seed) | |
else: | |
generator = None | |
assert isinstance(dataloader_kwargs, Dict) | |
train_dataloader = DataLoader( | |
dataset=self, batch_size=mini_batch_size, collate_fn=collate_fn, generator=generator, **dataloader_kwargs | |
) | |
def get_data(): | |
for _ in range(epochs): | |
for d in train_dataloader: | |
d.meta_info = self.meta_info | |
yield d | |
return iter(get_data()) | |
    def chunk(self, chunks: int) -> List["DataProto"]:
        """Split the batch along dim=0 into chunks. The meta_info is passed to each DataProto after the split.

        Args:
            chunks (int): the number of chunks to split on dim=0

        Returns:
            List[DataProto]: a list of DataProto after splitting
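
        Example (illustrative sketch)::

            shards = data.chunk(chunks=4)
            # len(shards) == 4 and each shard contains len(data) // 4 samples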
""" | |
assert len(self) % chunks == 0, ( | |
f"only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}." | |
) | |
if self.batch is not None: | |
batch_lst = self.batch.chunk(chunks=chunks, dim=0) | |
else: | |
batch_lst = [None for _ in range(chunks)] | |
non_tensor_batch_lst = [{} for _ in range(chunks)] | |
for key, value in self.non_tensor_batch.items(): | |
assert isinstance(value, np.ndarray) | |
non_tensor_lst = np.array_split(value, chunks) | |
assert len(non_tensor_lst) == chunks | |
for i in range(chunks): | |
non_tensor_batch_lst[i][key] = non_tensor_lst[i] | |
output = [] | |
for i in range(chunks): | |
output.append( | |
DataProto(batch=batch_lst[i], non_tensor_batch=non_tensor_batch_lst[i], meta_info=self.meta_info) | |
) | |
return output | |
    def split(self, split_size: int) -> List["DataProto"]:
        chunks = len(self) // split_size
        return self.chunk(chunks)

    @staticmethod
    def concat(data: List["DataProto"]) -> "DataProto":
        """Concat a list of DataProto. The batch is concatenated along dim=0.

        The meta_info is assumed to be identical and the first one is used.

        Args:
            data (List[DataProto]): list of DataProto

        Returns:
            DataProto: concatenated DataProto
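
        Example (illustrative sketch; shard_0 and shard_1 are hypothetical DataProtos)::

            merged = DataProto.concat([shard_0, shard_1])
            # len(merged) == len(shard_0) + len(shard_1)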
""" | |
batch_lst = [batch.batch for batch in data] | |
if batch_lst[0] is not None: | |
new_batch = torch.cat(batch_lst, dim=0) | |
else: | |
new_batch = None | |
non_tensor_batch = batch_collate([d.non_tensor_batch for d in data]) | |
for key, value in non_tensor_batch.items(): | |
non_tensor_batch[key] = np.concatenate(value, axis=0) | |
return DataProto(batch=new_batch, non_tensor_batch=non_tensor_batch, meta_info=data[0].meta_info) | |
    def reorder(self, indices: torch.Tensor) -> None:
        """
        Note that this operation is in-place.
        """
        indices_np = indices.detach().numpy()
        self.batch = self.batch[indices]
        self.non_tensor_batch = {key: value[indices_np] for key, value in self.non_tensor_batch.items()}

    def repeat(self, repeat_times: int = 2, interleave: bool = True) -> "DataProto":
        """
        Repeat the batch data a specified number of times.

        Args:
            repeat_times (int): Number of times to repeat the data.
            interleave (bool): Whether to interleave the repeated data.

        Returns:
            DataProto: A new DataProto with repeated data.
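
        Example (illustrative; for a batch with samples [a, b])::

            data.repeat(repeat_times=2, interleave=True)   # sample order: a, a, b, b
            data.repeat(repeat_times=2, interleave=False)  # sample order: a, b, a, b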
""" | |
if self.batch is not None: | |
if interleave: | |
# Interleave the data | |
repeated_tensors = { | |
key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() | |
} | |
else: | |
# Stack the data | |
repeated_tensors = { | |
key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) | |
for key, tensor in self.batch.items() | |
} | |
repeated_batch = TensorDict( | |
source=repeated_tensors, | |
batch_size=(self.batch.batch_size[0] * repeat_times,), | |
) | |
else: | |
repeated_batch = None | |
repeated_non_tensor_batch = {} | |
for key, value in self.non_tensor_batch.items(): | |
if interleave: | |
repeated_non_tensor_batch[key] = np.repeat(value, repeat_times, axis=0) | |
else: | |
repeated_non_tensor_batch[key] = np.tile(value, (repeat_times,) + (1,) * (value.ndim - 1)) | |
return DataProto( | |
batch=repeated_batch, | |
non_tensor_batch=repeated_non_tensor_batch, | |
meta_info=self.meta_info, | |
) | |
@dataclass
class DataProtoFuture:
    """
    DataProtoFuture aims to eliminate actual data fetching on the driver. By doing so, the driver does not have to
    wait for data, so asynchronous execution becomes possible.

    DataProtoFuture contains a list of futures from another WorkerGroup of size world_size.
    - collect_fn is a Callable that reduces the list of futures to a DataProto
    - dispatch_fn is a Callable that partitions the DataProto into a list of DataProto of size world_size and then
      selects the chunk needed by the destination

    Potential issue: we can optimize dispatch_fn(collect_fn) such that only the needed data is fetched on the
    destination.
    - DataProtoFuture only supports directly passing the output of one method to the input of another. You cannot
      perform any operation on a DataProtoFuture on the driver.
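
    Example (illustrative sketch; ``worker_group`` and ``generate_sequences`` are hypothetical, and the call is
    assumed to return a list of ray object refs instead of materialized data)::

        future = DataProtoFuture.concat(worker_group.generate_sequences(prompts))
        shards = future.chunk(chunks=world_size)  # still no data fetched on the driver
        data = shards[0].get()                    # materializes one DataProto shard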
""" | |
collect_fn: Callable | |
futures: List[ray.ObjectRef] | |
dispatch_fn: Callable = None | |
def concat(data: List[ray.ObjectRef]) -> "DataProtoFuture": | |
output = DataProtoFuture(collect_fn=DataProto.concat, futures=data) | |
return output | |
    def chunk(self, chunks: int) -> List["DataProtoFuture"]:
        from functools import partial

        arg_future_lst = []
        for i in range(chunks):
            # note that we can't directly pass i and chunks
            def dispatch_fn(x, i, chunks):
                return x.chunk(chunks=chunks)[i]

            arg_future = DataProtoFuture(
                collect_fn=self.collect_fn, dispatch_fn=partial(dispatch_fn, i=i, chunks=chunks), futures=self.futures
            )
            arg_future_lst.append(arg_future)

        return arg_future_lst

    def get(self):
        outputs = ray.get(self.futures)  # dp_size
        for output in outputs:
            assert isinstance(output, DataProto)

        outputs = self.collect_fn(outputs)  # select dp, concat
        if self.dispatch_fn is not None:
            outputs = self.dispatch_fn(outputs)  # split in batch dim, select using dp

        return outputs


def allgather_dict_tensors(
    tensors: Union[Dict[str, torch.Tensor], TensorDict], size: int, group: ProcessGroup, dim: int = 0
) -> Union[Dict[str, torch.Tensor], TensorDict]:
    """
    TODO: optimize this.
    - We can use async ops
    - We can use only one allgather
    """
    if isinstance(tensors, TensorDict):
        is_tensor_dict = True
        tensors_as_dict = tensors.to_dict()
    else:
        tensors_as_dict = tensors
        is_tensor_dict = False

    output = {}
    sorted_keys = sorted(tensors_as_dict.keys())
    for key in sorted_keys:
        value = tensors_as_dict[key]
        output[key] = [torch.empty_like(value) for _ in range(size)]
        torch.distributed.all_gather(output[key], value, group=group, async_op=False)
        output[key] = torch.cat(output[key], dim=dim)

    if is_tensor_dict:
        output = TensorDict(source=output, batch_size=tensors.batch_size[0] * size)

    return output


def all_gather_data_proto(data: DataProto, size: int, group: ProcessGroup) -> None:
    # Note that this is an in-place operator just like torch.distributed.all_gather
    prev_device = data.batch.device
    data.batch = data.batch.cuda(device=torch.cuda.current_device())
    data.batch = allgather_dict_tensors(data.batch.contiguous(), size=size, group=group, dim=0)
    data.batch = data.batch.to(prev_device)
    # all gather non_tensor_batch
    all_non_tensor_batch = [None for _ in range(size)]
    torch.distributed.all_gather_object(all_non_tensor_batch, data.non_tensor_batch, group=group)
    data.non_tensor_batch = {k: np.concatenate([d[k] for d in all_non_tensor_batch]) for k in data.non_tensor_batch}