Spaces: Running on L4
# --------------------------------------------------------
# Adapted from: https://github.com/openai/point-e
# Licensed under the MIT License
# Copyright (c) 2022 OpenAI
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# --------------------------------------------------------
| from typing import Dict, Iterator | |
| import torch | |
| import torch.nn as nn | |
| from .gaussian_diffusion import GaussianDiffusion | |
class PointCloudSampler:
    """
    A wrapper around a diffusion model that produces conditional point-cloud
    samples, optionally with classifier-free guidance.

    Args:
        model: Denoising network. Called as ``model(x_t, t, condition=...)``;
            its output's first ``point_dim`` channels are treated as the
            predicted noise (eps) and any remaining channels are passed
            through unchanged (see ``_uncond_guide_model``).
        diffusion: ``GaussianDiffusion`` instance providing
            ``ddim_sample_loop_progressive``.
        num_points: Number of points per sampled cloud.
        point_dim: Channels per point (default 3, presumably XYZ — confirm
            against the model's output layout).
        guidance_scale: Classifier-free guidance weight. A value of 0 or 1
            disables guidance (no internal batch doubling).
        clip_denoised: Forwarded to the diffusion sampler.
        sigma_min: Stored sampler setting; not used directly in this class.
        sigma_max: Stored sampler setting; not used directly in this class.
        s_churn: Stored sampler setting; not used directly in this class.
    """

    def __init__(
        self,
        model: nn.Module,
        diffusion: "GaussianDiffusion",
        num_points: int,
        point_dim: int = 3,
        guidance_scale: float = 3.0,
        clip_denoised: bool = True,
        sigma_min: float = 1e-3,
        sigma_max: float = 120,
        s_churn: float = 3,
    ):
        self.model = model
        self.diffusion = diffusion
        self.num_points = num_points
        self.point_dim = point_dim
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

    def sample_batch_progressive(
        self,
        batch_size: int,
        condition: torch.Tensor,
        noise=None,
        device=None,
        guidance_scale=None,
    ) -> Iterator[Dict[str, torch.Tensor]]:
        """
        Generate samples progressively using classifier-free guidance.

        Args:
            batch_size: Number of samples to generate.
            condition: Conditioning tensor (first dim is the batch).
            noise: Optional initial noise tensor.
            device: Device to run on.
            guidance_scale: Optional override for the instance's guidance
                scale.

        Yields:
            Dicts with ``"xstart"`` (current x0 prediction) and ``"xprev"``
            (current sample), each truncated to ``batch_size`` rows.
        """
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        sample_shape = (batch_size, self.point_dim, self.num_points)

        # Guidance is a no-op at scale 0 or 1, so skip the batch doubling
        # entirely in those cases. Computed once (was duplicated before).
        use_guidance = guidance_scale != 1 and guidance_scale != 0

        if use_guidance:
            # Double the batch for classifier-free guidance: first half is
            # conditional, second half unconditional (zeroed condition).
            # The same noise is reused for both halves so they share a
            # trajectory starting point.
            condition = torch.cat([condition, torch.zeros_like(condition)], dim=0)
            if noise is not None:
                noise = torch.cat([noise, noise], dim=0)
            model = self._uncond_guide_model(self.model, guidance_scale)
            internal_batch_size = batch_size * 2
        else:
            model = self.model
            internal_batch_size = batch_size

        model_kwargs = {"condition": condition}

        samples_it = self.diffusion.ddim_sample_loop_progressive(
            model,
            shape=(internal_batch_size, *sample_shape[1:]),
            model_kwargs=model_kwargs,
            device=device,
            clip_denoised=self.clip_denoised,
            noise=noise,
        )
        for x in samples_it:
            # Keep only the conditional half of the (possibly doubled) batch.
            # Bug fix: the "x" fallback was previously yielded unsliced, which
            # returned the full doubled batch under guidance; it must be
            # truncated to batch_size exactly like the other outputs.
            yield {
                "xstart": x["pred_xstart"][:batch_size],
                "xprev": (x["sample"] if "sample" in x else x["x"])[:batch_size],
            }

    def _uncond_guide_model(self, model: nn.Module, scale: float) -> nn.Module:
        """
        Wrap ``model`` for classifier-free guidance.

        The returned callable expects a doubled batch (conditional half
        followed by unconditional half). It evaluates the model on the
        duplicated first half, splits the eps channels into conditional and
        unconditional predictions, blends them as
        ``uncond + scale * (cond - uncond)``, and re-duplicates the result so
        the output batch matches the input batch. Channels beyond
        ``self.point_dim`` are passed through untouched.
        """

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            # First point_dim channels are eps; the rest (e.g. variance) are
            # forwarded unmodified.
            eps, rest = model_out[:, : self.point_dim], model_out[:, self.point_dim :]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn