# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# --------------------------------------------------------
from typing import Dict, Iterator

import torch
import torch.nn as nn

from .gaussian_diffusion import GaussianDiffusion
class PointCloudSampler:
    """
    A wrapper around a point-cloud diffusion model that produces conditional
    sample tensors via DDIM sampling with classifier-free guidance.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: "GaussianDiffusion",
        num_points: int,
        point_dim: int = 3,
        guidance_scale: float = 3.0,
        clip_denoised: bool = True,
        sigma_min: float = 1e-3,
        sigma_max: float = 120,
        s_churn: float = 3,
    ):
        """
        Args:
            model: Noise-prediction network; called as ``model(x_t, ts, **kwargs)``
                and expected to return eps stacked with extra channels along dim 1.
            diffusion: Diffusion process providing ``ddim_sample_loop_progressive``.
            num_points: Number of points per generated cloud.
            point_dim: Channels per point (default 3 for xyz).
            guidance_scale: Classifier-free guidance weight; 0 or 1 disables
                the doubled-batch guidance path.
            clip_denoised: Whether the sampler clips the predicted x_0.
            sigma_min: Stored sampler hyperparameter (not used in this class's
                visible code — presumably consumed by other sampling loops).
            sigma_max: Stored sampler hyperparameter (see sigma_min).
            s_churn: Stored sampler hyperparameter (see sigma_min).
        """
        self.model = model
        self.num_points = num_points
        self.point_dim = point_dim
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn
        self.diffusion = diffusion

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate
            condition: Conditioning tensor (batched along dim 0)
            noise: Optional initial noise tensor
            device: Device to run on
            guidance_scale: Optional override for guidance scale

        Returns:
            Iterator of dicts containing intermediate samples; each dict has
            "xstart" (current prediction of x_0) and "xprev" (current sample),
            both truncated to the first `batch_size` entries.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        sample_shape = (batch_size, self.point_dim, self.num_points)

        # Guidance is a no-op at scale 0 (unconditional) and 1 (conditional);
        # compute the flag once and reuse it below.
        use_guidance = guidance_scale != 1 and guidance_scale != 0

        # Double the batch for classifier-free guidance: first half conditional,
        # second half unconditional (zeroed condition). Noise is duplicated so
        # both halves start from the same x_T.
        if use_guidance:
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)
        model_kwargs = {"condition": condition}

        internal_batch_size = batch_size
        if use_guidance:
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size *= 2
        else:
            model = self.model

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )
        for x in samples_it:
            samples = {
                "xstart": x["pred_xstart"][:batch_size],
                # BUGFIX: the "x" fallback must also be truncated to batch_size —
                # under guidance the internal batch is doubled, and the previous
                # code returned all 2*batch_size rows in that branch.
                "xprev": x["sample"][:batch_size] if "sample" in x else x["x"][:batch_size],
            }
            yield samples

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
        """
        Wrap `model` for classifier-free guidance.

        The returned callable expects a doubled batch (conditional half followed
        by unconditional half), runs the model once on the duplicated first
        half, and blends the two eps predictions:
        eps = uncond + scale * (cond - uncond).
        Non-eps output channels ("rest") are passed through unchanged.
        """

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            # Both halves carry the same x_t; the condition kwarg distinguishes
            # the conditional and unconditional passes.
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn