# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import copy
import inspect
import warnings
from typing import Any, List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
import torch.nn as nn

from .device_utils import Device, make_device


class TensorAccessor(nn.Module):
    """
    A helper class to be used with the __getitem__ method. This can be used for
    getting/setting the values for an attribute of a class at one particular
    index. This is useful when the attributes of a class are batched tensors
    and one element in the batch needs to be modified.
    """

    def __init__(self, class_object, index: Union[int, slice]) -> None:
        """
        Args:
            class_object: this should be an instance of a class which has
                attributes which are tensors representing a batch of
                values.
            index: int/slice, an index indicating the position in the batch.
                In __setattr__ and __getattr__ only the value of class
                attributes at this index will be accessed.
        """
        self.__dict__["class_object"] = class_object
        self.__dict__["index"] = index

    def __setattr__(self, name: str, value: Any):
        """
        Update the attribute given by `name` to the value given by `value`
        at the index specified by `self.index`.

        Args:
            name: str, name of the attribute.
            value: value to set the attribute to.
        """
        v = getattr(self.class_object, name)
        if not torch.is_tensor(v):
            msg = "Can only set values on attributes which are tensors; got %r"
            raise AttributeError(msg % type(v))

        # Convert the value to a tensor if it is not already one.
        if not torch.is_tensor(value):
            value = torch.tensor(
                value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad
            )

        # Check the shapes match the existing shape and the shape of the index.
        if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]:
            msg = "Expected value to have shape %r; got %r"
            raise ValueError(msg % (v.shape, value.shape))
        if (
            v.dim() == 0
            and isinstance(self.index, slice)
            and len(value) != len(self.index)
        ):
            msg = "Expected value to have len %r; got %r"
            raise ValueError(msg % (len(self.index), len(value)))
        self.class_object.__dict__[name][self.index] = value

    def __getattr__(self, name: str):
        """
        Return the value of the attribute given by "name" on self.class_object
        at the index specified in self.index.

        Args:
            name: string of the attribute name
        """
        if hasattr(self.class_object, name):
            return self.class_object.__dict__[name][self.index]
        else:
            msg = "Attribute %s not found on %r"
            raise AttributeError(msg % (name, type(self.class_object).__name__))


BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray)


class TensorProperties(nn.Module):
    """
    A mix-in class for storing tensors as properties with helper methods.
    """

    def __init__(
        self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
    ) -> None:
        """
        Args:
            dtype: data type to set for the inputs
            device: Device (as str or torch.device)
            kwargs: any number of keyword arguments. Any arguments which are
                of type (float/int/list/tuple/tensor/array) are broadcasted and
                other keyword arguments are set as attributes.
        """
        super().__init__()
        self.device = make_device(device)
        self._N = 0
        if kwargs is not None:
            # broadcast all inputs which are float/int/list/tuple/tensor/array
            # set as attributes anything else e.g. strings, bools
            args_to_broadcast = {}
            for k, v in kwargs.items():
                if v is None or isinstance(v, (str, bool)):
                    setattr(self, k, v)
                elif isinstance(v, BROADCAST_TYPES):
                    args_to_broadcast[k] = v
                else:
                    msg = "Arg %s with type %r is not broadcastable"
                    warnings.warn(msg % (k, type(v)))

            names = args_to_broadcast.keys()
            # convert from type dict.values to tuple
            values = tuple(v for v in args_to_broadcast.values())

            if len(values) > 0:
                broadcasted_values = convert_to_tensors_and_broadcast(
                    *values, device=device
                )
                # Set broadcasted values as attributes on self.
                for i, n in enumerate(names):
                    setattr(self, n, broadcasted_values[i])
                    if self._N == 0:
                        self._N = broadcasted_values[i].shape[0]
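
    # Illustrative sketch (not part of the library): constructing a
    # TensorProperties instance broadcasts the tensor-like kwargs to a common
    # batch dimension, while strings/bools/None are stored unchanged. The
    # attribute names below are hypothetical examples.
    #
    #   props = TensorProperties(
    #       device="cpu",
    #       ambient_color=((0.5, 0.5, 0.5),),  # shape (1, 3)
    #       diffuse_color=torch.rand(4, 3),    # shape (4, 3)
    #       shininess=64.0,                    # scalar
    #       name="example",                    # set as a plain attribute
    #   )
    #   # props.ambient_color.shape == (4, 3), props.diffuse_color.shape == (4, 3)
    #   # props.shininess.shape == (4,), len(props) == 4, props.name == "example"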

    def __len__(self) -> int:
        return self._N

    def isempty(self) -> bool:
        return self._N == 0

    def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
        """
        Args:
            index: an int or slice used to index all the fields.

        Returns:
            if `index` is an index int/slice return a TensorAccessor class
            with getattribute/setattribute methods which return/update the value
            at the index in the original class.
        """
        if isinstance(index, (int, slice)):
            return TensorAccessor(class_object=self, index=index)

        msg = "Expected index of type int or slice; got %r"
        raise ValueError(msg % type(index))
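
    # Illustrative sketch (not part of the library): indexing a
    # TensorProperties instance returns a TensorAccessor, which reads and
    # writes the batched tensor attributes at that index. `diffuse_color` is
    # a hypothetical attribute name used only for this example.
    #
    #   props = TensorProperties(diffuse_color=torch.rand(4, 3))
    #   accessor = props[2]                       # TensorAccessor for batch element 2
    #   color = accessor.diffuse_color            # tensor of shape (3,)
    #   accessor.diffuse_color = [1.0, 0.0, 0.0]  # updates props.diffuse_color[2]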

    # pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
    def to(self, device: Device = "cpu") -> "TensorProperties":
        """
        In place operation to move class properties which are tensors to a
        specified device. If self has a property "device", update this as well.
        """
        device_ = make_device(device)
        for k in dir(self):
            v = getattr(self, k)
            if k == "device":
                setattr(self, k, device_)
            if torch.is_tensor(v) and v.device != device_:
                setattr(self, k, v.to(device_))
        return self

    def cpu(self) -> "TensorProperties":
        return self.to("cpu")

    # pyre-fixme[14]: `cuda` overrides method defined in `Module` inconsistently.
    def cuda(self, device: Optional[int] = None) -> "TensorProperties":
        return self.to(f"cuda:{device}" if device is not None else "cuda")

    def clone(self, other) -> "TensorProperties":
        """
        Update the tensor properties of other with the cloned properties of self.
        """
        for k in dir(self):
            v = getattr(self, k)
            if inspect.ismethod(v) or k.startswith("__") or type(v) is TypeVar:
                continue
            if torch.is_tensor(v):
                v_clone = v.clone()
            else:
                v_clone = copy.deepcopy(v)
            setattr(other, k, v_clone)
        return other

    def gather_props(self, batch_idx) -> "TensorProperties":
        """
        This is an in place operation to reformat all tensor class attributes
        based on a set of given indices using torch.gather. This is useful when
        attributes which are batched tensors e.g. shape (N, 3) need to be
        multiplied with another tensor which has a different first dimension
        e.g. packed vertices of shape (V, 3).

        Example

        .. code-block:: python

            self.specular_color = (N, 3) tensor of specular colors for each mesh

        A lighting calculation may use

        .. code-block:: python

            verts_packed = meshes.verts_packed()  # (V, 3)

        To multiply these two tensors the batch dimension needs to be the same.
        To achieve this we can do

        .. code-block:: python

            batch_idx = meshes.verts_packed_to_mesh_idx()  # (V)

        This gives the index of the mesh for each vertex in verts_packed.

        .. code-block:: python

            self.gather_props(batch_idx)
            self.specular_color = (V, 3) tensor with the specular color for
                                         each packed vertex.

        torch.gather requires the index tensor to have the same shape as the
        input tensor so this method takes care of the reshaping of the index
        tensor to use with class attributes with arbitrary dimensions.

        Args:
            batch_idx: shape (B, ...) where `...` represents an arbitrary
                number of dimensions

        Returns:
            self with all properties reshaped. e.g. a property with shape (N, 3)
            is transformed to shape (B, 3).
        """
        # Iterate through the attributes of the class which are tensors.
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v):
                if v.shape[0] > 1:
                    # There are different values for each batch element
                    # so gather these using the batch_idx.
                    # First clone the input batch_idx tensor before
                    # modifying it.
                    _batch_idx = batch_idx.clone()
                    idx_dims = _batch_idx.shape
                    tensor_dims = v.shape
                    if len(idx_dims) > len(tensor_dims):
                        msg = "batch_idx cannot have more dimensions than %s. "
                        msg += "got shape %r and %s has shape %r"
                        raise ValueError(msg % (k, idx_dims, k, tensor_dims))
                    if idx_dims != tensor_dims:
                        # To use torch.gather the index tensor (_batch_idx) has
                        # to have the same shape as the input tensor.
                        new_dims = len(tensor_dims) - len(idx_dims)
                        new_shape = idx_dims + (1,) * new_dims
                        expand_dims = (-1,) + tensor_dims[1:]
                        _batch_idx = _batch_idx.view(*new_shape)
                        _batch_idx = _batch_idx.expand(*expand_dims)

                    v = v.gather(0, _batch_idx)
                    setattr(self, k, v)
        return self
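
    # Illustrative sketch (not part of the library): gather_props expands
    # per-batch attributes to per-element attributes. The attribute name
    # `specular_color` and the shapes below are hypothetical.
    #
    #   props = TensorProperties(specular_color=torch.rand(2, 3))  # (N=2, 3)
    #   batch_idx = torch.tensor([0, 0, 1, 1, 1])                  # (V=5,)
    #   props.gather_props(batch_idx)
    #   # props.specular_color now has shape (5, 3); rows 0-1 hold the values
    #   # of mesh 0 and rows 2-4 hold the values of mesh 1.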


def format_tensor(
    input, dtype: torch.dtype = torch.float32, device: Device = "cpu"
) -> torch.Tensor:
    """
    Helper function for converting a scalar value to a tensor.

    Args:
        input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
        dtype: data type for the input
        device: Device (as str or torch.device) on which the tensor should be placed.

    Returns:
        input_vec: torch tensor with optional added batch dimension.
    """
    device_ = make_device(device)
    if not torch.is_tensor(input):
        input = torch.tensor(input, dtype=dtype, device=device_)
    if input.dim() == 0:
        input = input.view(1)
    if input.device == device_:
        return input

    input = input.to(device=device_)
    return input
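
# Illustrative sketch (not part of the library): format_tensor always returns
# a tensor with at least one dimension on the requested device.
#
#   format_tensor(2.0)                # tensor([2.]) on CPU, shape (1,)
#   format_tensor([1.0, 2.0, 3.0])    # tensor([1., 2., 3.]), shape (3,)
#   format_tensor(torch.tensor(5.0))  # shape () input becomes shape (1,)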


def convert_to_tensors_and_broadcast(
    *args, dtype: torch.dtype = torch.float32, device: Device = "cpu"
):
    """
    Helper function to handle parsing an arbitrary number of inputs (*args)
    which all need to have the same batch dimension.
    The output is a list of tensors.

    Args:
        *args: an arbitrary number of inputs
            Each of the values in `args` can be one of the following
                - Python scalar
                - Torch scalar
                - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are
                  an arbitrary number of dimensions which can vary for each
                  value in args. In this case each input is broadcast to a
                  tensor of shape (N, K_i)
        dtype: data type to use when creating new tensors.
        device: torch device on which the tensors should be placed.

    Output:
        args: A list of tensors of shape (N, K_i)
    """
    # Convert all inputs to tensors with a batch dimension
    args_1d = [format_tensor(c, dtype, device) for c in args]

    # Find broadcast size
    sizes = [c.shape[0] for c in args_1d]
    N = max(sizes)

    args_Nd = []
    for c in args_1d:
        if c.shape[0] != 1 and c.shape[0] != N:
            msg = "Got non-broadcastable sizes %r" % sizes
            raise ValueError(msg)

        # Expand broadcast dim and keep non broadcast dims the same size
        expand_sizes = (N,) + (-1,) * len(c.shape[1:])
        args_Nd.append(c.expand(*expand_sizes))

    return args_Nd
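
# Illustrative sketch (not part of the library): inputs with batch dimension 1
# (or scalars) are expanded to match the largest batch dimension.
#
#   a, b, c = convert_to_tensors_and_broadcast(
#       2.0,               # scalar  -> shape (1,)  -> (3,)
#       torch.rand(1, 4),  # (1, 4)  -> (3, 4)
#       torch.rand(3, 4),  # (3, 4)  -> unchanged
#   )
#   # a.shape == (3,), b.shape == (3, 4), c.shape == (3, 4)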


def ndc_grid_sample(
    input: torch.Tensor,
    grid_ndc: torch.Tensor,
    *,
    align_corners: bool = False,
    **grid_sample_kwargs
) -> torch.Tensor:
    """
    Samples a tensor `input` of shape `(B, dim, H, W)` at 2D locations
    specified by a tensor `grid_ndc` of shape `(B, ..., 2)` using
    the `torch.nn.functional.grid_sample` function.

    `grid_ndc` is specified in PyTorch3D NDC coordinate frame.

    Args:
        input: The tensor of shape `(B, dim, H, W)` to be sampled.
        grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
            2D locations at which `input` is sampled.
            See [1] for a detailed description of the NDC coordinates.
        align_corners: Forwarded to the `torch.nn.functional.grid_sample`
            call. See its docstring.
        grid_sample_kwargs: Additional arguments forwarded to the
            `torch.nn.functional.grid_sample` call. See the corresponding
            docstring for a listing of the corresponding arguments.

    Returns:
        sampled_input: A tensor of shape `(B, dim, ...)` containing the samples
            of `input` at 2D locations `grid_ndc`.

    References:
        [1] https://pytorch3d.org/docs/cameras
    """
    batch, *spatial_size, pt_dim = grid_ndc.shape
    if batch != input.shape[0]:
        raise ValueError("'input' and 'grid_ndc' have to have the same batch size.")
    if input.ndim != 4:
        raise ValueError("'input' has to be a 4-dimensional Tensor.")
    if pt_dim != 2:
        raise ValueError("The last dimension of 'grid_ndc' has to be == 2.")

    grid_ndc_flat = grid_ndc.reshape(batch, -1, 1, 2)

    # pyre-fixme[6]: For 2nd param expected `Tuple[int, int]` but got `Size`.
    grid_flat = ndc_to_grid_sample_coords(grid_ndc_flat, input.shape[2:])

    sampled_input_flat = torch.nn.functional.grid_sample(
        input, grid_flat, align_corners=align_corners, **grid_sample_kwargs
    )

    sampled_input = sampled_input_flat.reshape([batch, input.shape[1], *spatial_size])

    return sampled_input
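
# Illustrative sketch (not part of the library): sampling a feature map at a
# few NDC locations. The shapes below are hypothetical.
#
#   feats = torch.rand(2, 16, 64, 64)           # (B, dim, H, W)
#   points_ndc = torch.rand(2, 100, 2) * 2 - 1  # (B, 100, 2) in [-1, 1]
#   sampled = ndc_grid_sample(feats, points_ndc)
#   # sampled.shape == (2, 16, 100)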


def ndc_to_grid_sample_coords(
    xy_ndc: torch.Tensor, image_size_hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Convert from PyTorch3D's NDC coordinates to
    `torch.nn.functional.grid_sample`'s coordinates.

    Args:
        xy_ndc: Tensor of shape `(..., 2)` containing 2D points in
            PyTorch3D's NDC coordinates.
        image_size_hw: A tuple `(image_height, image_width)` denoting the
            height and width of the image tensor to sample.

    Returns:
        xy_grid_sample: Tensor of shape `(..., 2)` containing 2D points in the
            `torch.nn.functional.grid_sample` coordinates.
    """
    if len(image_size_hw) != 2 or any(s <= 0 for s in image_size_hw):
        raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers")
    aspect = min(image_size_hw) / max(image_size_hw)
    xy_grid_sample = -xy_ndc  # first negate the coords
    if image_size_hw[0] >= image_size_hw[1]:
        xy_grid_sample[..., 1] *= aspect
    else:
        xy_grid_sample[..., 0] *= aspect
    return xy_grid_sample
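
# Illustrative sketch (not part of the library): PyTorch3D NDC axes point
# left/up while grid_sample expects x to the right and y down, so both
# coordinates are negated; the coordinate along the longer image side is
# additionally rescaled by the aspect ratio for non-square images.
#
#   pts = torch.tensor([[1.0, 1.0]])
#   ndc_to_grid_sample_coords(pts, (64, 64))   # -> tensor([[-1., -1.]])
#   ndc_to_grid_sample_coords(pts, (64, 128))  # -> tensor([[-0.5, -1.]])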


def parse_image_size(
    image_size: Union[List[int], Tuple[int, int], int]
) -> Tuple[int, int]:
    """
    Args:
        image_size: A single int (for square images) or a tuple/list of two ints.

    Returns:
        A tuple of two ints.

    Throws:
        ValueError if got more than two ints, any negative numbers or non-ints.
    """
    if not isinstance(image_size, (tuple, list)):
        return (image_size, image_size)
    if len(image_size) != 2:
        raise ValueError("Image size can only be a tuple/list of (H, W)")
    if not all(i > 0 for i in image_size):
        raise ValueError(
            "Image sizes must be greater than 0; got %d, %d" % tuple(image_size)
        )
    if not all(isinstance(i, int) for i in image_size):
        raise ValueError(
            "Image sizes must be integers; got %f, %f" % tuple(image_size)
        )
    return tuple(image_size)
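
# Illustrative sketch (not part of the library): a single int is interpreted
# as a square image, while a (H, W) pair is validated and returned as a tuple.
#
#   parse_image_size(256)         # -> (256, 256)
#   parse_image_size((128, 256))  # -> (128, 256)
#   parse_image_size([64])        # -> raises ValueError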