# ReconViaGen: trellis/pipelines/trellis_image_to_3d.py
from typing import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from PIL import Image
import trimesh
import os
import random
import trellis.modules.sparse as sp
from trellis.models.sparse_structure_vae import *
from contextlib import contextmanager
import sys
sys.path.append("wheels/vggt")
from wheels.vggt.vggt.models.vggt import VGGT
from scipy.spatial.transform import Rotation
from transformers import AutoModelForImageSegmentation
import rembg
def export_point_cloud(xyz, color):
# Convert tensors to numpy arrays if needed
if isinstance(xyz, torch.Tensor):
xyz = xyz.detach().cpu().numpy()
if isinstance(color, torch.Tensor):
color = color.detach().cpu().numpy()
color = (color * 255).astype(np.uint8)
# Create point cloud using trimesh
point_cloud = trimesh.PointCloud(vertices=xyz, colors=color)
return point_cloud
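# Usage sketch (illustrative; `xyz` and `rgb` are hypothetical stand-ins for
# real pipeline outputs):
#
#   xyz = torch.rand(1000, 3)           # (N, 3) point positions
#   rgb = torch.rand(1000, 3)           # (N, 3) colors in [0, 1]
#   export_point_cloud(xyz, rgb).export("points.ply")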
def normalize_trimesh(mesh):
    # Center the mesh at its centroid, then scale so the largest
    # bounding-box extent is 1 (the mesh fits inside a unit cube)
    centroid = mesh.centroid
    mesh.vertices -= centroid
    extents = mesh.extents
    scale = max(extents)
    mesh.vertices /= scale
    return mesh
def random_sample_rotation(rotation_factor: float = 1.0) -> np.ndarray:
    # Random Euler angles (z, y, x), each in [0, 2 * pi / rotation_factor)
    euler = np.random.rand(3) * np.pi * 2 / rotation_factor
rotation = Rotation.from_euler('zyx', euler).as_matrix()
return rotation
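# Usage sketch (illustrative): apply a random rotation to mesh vertices
# about the origin.
#
#   R = random_sample_rotation()           # (3, 3) rotation matrix
#   mesh.vertices = mesh.vertices @ R.T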
from scipy.ndimage import binary_dilation
def voxelize_trimesh(mesh, resolution=(64, 64, 64), stride=4):
"""
Voxelize a given trimesh object with the specified resolution, incorporating 4x anti-aliasing.
First voxelizes at a 4x resolution and then downsamples to the target resolution.
Args:
mesh (trimesh.Trimesh): The input trimesh object to be voxelized.
resolution (tuple): The voxel grid resolution as (x, y, z). Default is (64, 64, 64).
Returns:
np.ndarray: A boolean numpy array representing the voxel grid where True indicates
the presence of the mesh in that voxel and False otherwise.
"""
target_density = max(resolution)
target_edge_length = 1.0 / target_density
max_edge_for_subdivision = target_edge_length / 2
# Calculate the higher resolution for 4x anti-aliasing
anti_aliasing_density = target_density * stride
anti_aliasing_edge_length = 1.0 / anti_aliasing_density
anti_aliasing_max_edge_for_subdivision = anti_aliasing_edge_length / 2
# Get the vertices and faces of the mesh
vertices = mesh.vertices
faces = mesh.faces
# Subdivide the mesh for the higher resolution voxelization
try:
new_vertices, new_faces = trimesh.remesh.subdivide_to_size(
vertices, faces, anti_aliasing_max_edge_for_subdivision
)
subdivided_mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
except Exception as e:
print(f"Unexpected error during mesh subdivision for anti-aliasing: {e}")
raise
# Voxelize the subdivided mesh at the higher resolution
try:
high_res_voxel_grid = subdivided_mesh.voxelized(
pitch=anti_aliasing_edge_length, method="binvox", exact=True
)
    except Exception as e:
        print(f"Voxelization using 'binvox' method failed for anti-aliasing: {e}")
        print("Falling back to default voxelization method for anti-aliasing.")
        high_res_voxel_grid = subdivided_mesh.voxelized(pitch=anti_aliasing_edge_length)
high_res_boolean_array = high_res_voxel_grid.matrix.astype(bool)
    x_stride = y_stride = z_stride = stride
downsampled_shape = (
high_res_boolean_array.shape[0] // x_stride,
high_res_boolean_array.shape[1] // y_stride,
high_res_boolean_array.shape[2] // z_stride
)
    # Use NumPy's strided tricks to efficiently view sub-cubes for max-pool downsampling
shape = (downsampled_shape[0], downsampled_shape[1], downsampled_shape[2], x_stride, y_stride, z_stride)
strides = (x_stride * high_res_boolean_array.strides[0],
y_stride * high_res_boolean_array.strides[1],
z_stride * high_res_boolean_array.strides[2],
high_res_boolean_array.strides[0],
high_res_boolean_array.strides[1],
high_res_boolean_array.strides[2])
sub_cubes = np.lib.stride_tricks.as_strided(high_res_boolean_array, shape=shape, strides=strides)
downsampled_array = np.any(sub_cubes, axis=(3, 4, 5))
return downsampled_array
def get_occupied_coordinates(voxel_grid):
    # Find the indices of occupied voxels
    occupied_indices = np.argwhere(voxel_grid)
    coords = torch.tensor(occupied_indices, dtype=torch.int32)  # int32 avoids overflow for grids larger than 127^3
    # Prepend a batch-index column of zeros and shift coordinates by 1
    coords = torch.cat([torch.zeros(coords.shape[0], 1, dtype=torch.int32), coords + 1], dim=1)
    # Move to GPU if required
    coords = coords.to('cuda:0')
    return coords
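# Usage sketch (illustrative): turn a mesh into sparse occupancy coordinates
# for the pipeline. Assumes a CUDA device is available (note the hard-coded
# 'cuda:0' above).
#
#   mesh = normalize_trimesh(trimesh.load("asset.glb", force='mesh'))
#   grid = voxelize_trimesh(mesh, resolution=(64, 64, 64), stride=4)
#   coords = get_occupied_coordinates(grid)   # (N, 4): [batch, x, y, z]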
from .base import Pipeline
from . import samplers
from ..modules import sparse as sp
class TrellisImageTo3DPipeline(Pipeline):
"""
Pipeline for inferring Trellis image-to-3D models.
Args:
models (dict[str, nn.Module]): The models to use in the pipeline.
sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
slat_sampler (samplers.Sampler): The sampler for the structured latent.
slat_normalization (dict): The normalization parameters for the structured latent.
image_cond_model (str): The name of the image conditioning model.
"""
default_image_resolution = 518
def __init__(
self,
models: dict[str, nn.Module] = None,
sparse_structure_sampler: samplers.Sampler = None,
slat_sampler: samplers.Sampler = None,
slat_normalization: dict = None,
image_cond_model: str = None,
):
if models is None:
return
super().__init__(models)
self.sparse_structure_sampler = sparse_structure_sampler
self.slat_sampler = slat_sampler
self.sparse_structure_sampler_params = {}
self.slat_sampler_params = {}
self.slat_normalization = slat_normalization
self._init_image_cond_model(image_cond_model)
@staticmethod
def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
new_pipeline = TrellisImageTo3DPipeline()
new_pipeline.__dict__ = pipeline.__dict__
args = pipeline._pretrained_args
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
new_pipeline.slat_sampler_params = args['slat_sampler']['params']
new_pipeline.slat_normalization = args['slat_normalization']
new_pipeline._init_image_cond_model(args['image_cond_model'])
return new_pipeline
def _init_image_cond_model(self, name: str):
"""
Initialize the image conditioning model.
"""
        try:
            # Prefer a locally cached copy of the DINOv2 hub repo
            dinov2_model = torch.hub.load(os.path.join(torch.hub.get_dir(), 'facebookresearch_dinov2_main'), name, source='local', pretrained=True)
        except Exception:
            dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
dinov2_model.eval()
self.models['image_cond_model'] = dinov2_model
transform = transforms.Compose([
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.image_cond_model_transform = transform
def preprocess_image(self, input: Image.Image, resolution=518, no_background=True, recenter=True) -> Image.Image:
"""
Preprocess the input image using BiRefNet for background removal.
        Includes padding to maintain the aspect ratio when resizing to `resolution` x `resolution`.
"""
# if has alpha channel, use it directly
has_alpha = False
if input.mode == 'RGBA':
alpha = np.array(input)[:, :, -1]
if not np.all(alpha == 255):
has_alpha = True
if has_alpha:
output = input
else:
input = input.convert('RGB')
max_size = max(input.size)
scale = min(1, 1024 / max_size)
if scale < 1:
input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
# Get mask using BiRefNet
mask = self._get_birefnet_mask(input)
# Convert input to RGBA and apply mask
input_rgba = input.convert('RGBA')
input_array = np.array(input_rgba)
input_array[:, :, 3] = mask * 255 # Apply mask to alpha channel
output = Image.fromarray(input_array)
# Process the output image
output_np = np.array(output)
alpha = output_np[:, :, 3]
# Find bounding box of non-transparent pixels
bbox = np.argwhere(alpha > 0.8 * 255)
if len(bbox) == 0: # Handle case where no foreground is detected
return input.convert('RGB')
bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
size = int(size * 1.1)
height, width = alpha.shape
if not recenter:
center = [width / 2, height / 2]
size = max(bbox[2] - bbox[0],
bbox[3] - bbox[1],
(bbox[2] - width / 2) * 2,
(width / 2 - bbox[0]) * 2,
(height / 2 - bbox[1]) * 2,
(bbox[3] - height / 2) * 2)
# Calculate and apply crop bbox
if not no_background:
if height > width:
center[0] = width / 2
if center[1] < width / 2:
center[1] = width / 2
elif center[1] > height - width / 2:
center[1] = height - width / 2
else:
center[1] = height / 2
if center[0] < height / 2:
center[0] = height / 2
elif center[0] > width - height / 2:
center[0] = width - height / 2
size = min(center[0], center[1], input.width - center[0], input.height - center[1], size) * 2
bbox = (
int(center[0] - size // 2),
int(center[1] - size // 2),
int(center[0] + size // 2),
int(center[1] + size // 2)
)
# Ensure bbox is within image bounds
bbox = (
max(0, bbox[0]),
max(0, bbox[1]),
min(output.width, bbox[2]),
min(output.height, bbox[3])
)
output = output.crop(bbox)
# Add padding to maintain aspect ratio
width, height = output.size
if width > height:
new_height = width
padding = (width - height) // 2
padded_output = Image.new('RGBA', (width, new_height), (0, 0, 0, 0))
padded_output.paste(output, (0, padding))
else:
new_width = height
padding = (height - width) // 2
padded_output = Image.new('RGBA', (new_width, height), (0, 0, 0, 0))
padded_output.paste(output, (padding, 0))
        # Resize the padded image to the target resolution
padded_output = torch.from_numpy(np.array(padded_output).astype(np.float32)) / 255
padded_output = F.interpolate(padded_output.unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False)[0].permute(1, 2, 0)
# Final processing
output = padded_output.cpu().numpy()
if no_background:
output = np.dstack((
output[:, :, :3] * (output[:, :, 3:4] > 0.8), # RGB channels premultiplied by alpha
output[:, :, 3] # Original alpha channel
))
output = Image.fromarray((output * 255).astype(np.uint8), mode='RGBA')
return output
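    # Usage sketch (illustrative): background removal and square cropping.
    # Note that `_get_birefnet_mask` needs `self.birefnet_model`, which is
    # only loaded in TrellisVGGTTo3DPipeline.from_pretrained, so this path
    # assumes that subclass (or an RGBA input, which bypasses the mask).
    #
    #   rgba = pipeline.preprocess_image(Image.open("photo.jpg"), resolution=518)
    #   rgba.save("masked.png")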
def _get_birefnet_mask(self, image: Image.Image) -> np.ndarray:
"""Get object mask using BiRefNet"""
image_size = (1024, 1024)
transform_image = transforms.Compose([
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
input_images = transform_image(image).unsqueeze(0).to(self.device)
with torch.no_grad():
preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image.size)
mask_np = np.array(mask)
return (mask_np > 128).astype(np.uint8)
@torch.no_grad()
def encode_image(self, image: Union[torch.Tensor, list[Image.Image]], w_layernorm=True) -> torch.Tensor:
"""
Encode the image.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image to encode
Returns:
torch.Tensor: The encoded features.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False)
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).to(self.device)
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
image = self.image_cond_model_transform(image).to(self.device)
features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
if w_layernorm:
features = F.layer_norm(features, features.shape[-1:])
return features
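    # Usage sketch (illustrative): DINOv2 prenorm tokens used as conditioning.
    # A 518x518 input with 14x14 patches gives 37*37 = 1369 patch tokens plus
    # the class token (and register tokens for register variants; the
    # `[:, :, 5:]` slice in TrellisVGGTTo3DPipeline.run suggests 1 class + 4
    # register tokens here).
    #
    #   feats = pipeline.encode_image([Image.open("photo.jpg")])
    #   print(feats.shape)   # (1, num_tokens, feature_dim)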
def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
cond = self.encode_image(image)
neg_cond = torch.zeros_like(cond)
return {
'cond': cond,
'neg_cond': neg_cond,
}
def sample_sparse_structure(
self,
cond: dict,
num_samples: int = 1,
sampler_params: dict = {},
noise: torch.Tensor = None,
) -> torch.Tensor:
"""
Sample sparse structures with the given conditioning.
Args:
cond (dict): The conditioning information.
num_samples (int): The number of samples to generate.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample occupancy latent
flow_model = self.models['sparse_structure_flow_model']
reso = flow_model.resolution
if noise is None:
noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
z_s = self.sparse_structure_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True
).samples
# Decode occupancy latent
decoder = self.models['sparse_structure_decoder']
coords = torch.argwhere(decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
return coords
def encode_slat(
self,
slat: sp.SparseTensor,
):
ret = {}
slat = self.models['slat_encoder'](slat, sample_posterior=False)
ret['slat'] = slat
return ret
@torch.no_grad()
def decode_slat(
self,
slat: sp.SparseTensor,
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
) -> dict:
"""
Decode the structured latent.
Args:
slat (sp.SparseTensor): The structured latent.
formats (List[str]): The formats to decode the structured latent to.
Returns:
dict: The decoded structured latent.
"""
ret = {}
ret['slat'] = slat
if 'gaussian' in formats:
ret['gaussian'] = self.models['slat_decoder_gs'](slat)
if 'mesh' in formats:
ret['mesh'] = self.models['slat_decoder_mesh'](slat)
if 'radiance_field' in formats:
ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
return ret
def sample_slat(
self,
cond: dict,
coords: torch.Tensor,
sampler_params: dict = {},
) -> sp.SparseTensor:
"""
Sample structured latent with the given conditioning.
Args:
cond (dict): The conditioning information.
coords (torch.Tensor): The coordinates of the sparse structure.
sampler_params (dict): Additional parameters for the sampler.
"""
# Sample structured latent
flow_model = self.models['slat_flow_model']
noise = sp.SparseTensor(
feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
coords=coords,
)
sampler_params = {**self.slat_sampler_params, **sampler_params}
slat = self.slat_sampler.sample(
flow_model,
noise,
**cond,
**sampler_params,
verbose=True
).samples
std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
slat = slat * std + mean
return slat
def get_input(self, batch_data):
std = torch.tensor(self.slat_normalization['std'])[None].to(self.device)
mean = torch.tensor(self.slat_normalization['mean'])[None].to(self.device)
images = batch_data['source_image']
cond = self.encode_image(images)
if random.random() > 0.5:
cond = torch.zeros_like(cond)
target_feats = batch_data['target_feats']
target_coords = batch_data['target_coords']
targets = sp.SparseTensor(target_feats, target_coords).to(self.device)
targets = (targets - mean) / std
noise = sp.SparseTensor(
feats=torch.randn_like(target_feats).to(self.device),
coords=target_coords.to(self.device),
)
return targets, cond, noise
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
return self.slat_flow_model(x, t, cond)
@contextmanager
def inject_sampler_multi_image(
self,
sampler_name: str,
num_images: int,
num_steps: int,
mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
):
"""
Inject a sampler with multiple images as condition.
Args:
sampler_name (str): The name of the sampler to inject.
num_images (int): The number of images to condition on.
num_steps (int): The number of steps to run the sampler for.
"""
sampler = getattr(self, sampler_name)
        setattr(sampler, '_old_inference_model', sampler._inference_model)
if mode == 'stochastic':
if num_images > num_steps:
print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
"This may lead to performance degradation.\033[0m")
cond_indices = (np.arange(num_steps) % num_images).tolist()
def _new_inference_model(self, model, x_t, t, cond, **kwargs):
cond_idx = cond_indices.pop(0)
cond_i = cond[cond_idx:cond_idx+1]
return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)
        elif mode == 'multidiffusion':
from .samplers import FlowEulerSampler
def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
if cfg_interval[0] <= t <= cfg_interval[1]:
preds = []
for i in range(len(cond)):
preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
pred = sum(preds) / len(preds)
neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
return (1 + cfg_strength) * pred - cfg_strength * neg_pred
else:
preds = []
for i in range(len(cond)):
preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
pred = sum(preds) / len(preds)
return pred
else:
raise ValueError(f"Unsupported mode: {mode}")
        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
        try:
            yield
        finally:
            # Restore the original inference model even if sampling raises
            sampler._inference_model = sampler._old_inference_model
            delattr(sampler, '_old_inference_model')
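    # Usage sketch (illustrative): temporarily patch a sampler so each
    # denoising step sees one of several image conditions.
    #
    #   with pipeline.inject_sampler_multi_image('slat_sampler', num_images=4,
    #                                            num_steps=25, mode='stochastic'):
    #       slat = pipeline.sample_slat(cond, coords, {})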
@torch.no_grad()
def run_multi_image(
self,
images: List[Image.Image],
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
slat_sampler_params: dict = {},
formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
preprocess_image: bool = True,
mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
):
"""
Run the pipeline with multiple images as condition
Args:
images (List[Image.Image]): The multi-view images of the assets
num_samples (int): The number of samples to generate.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
if preprocess_image:
images = [self.preprocess_image(image) for image in images]
cond = self.get_cond(images)
cond['neg_cond'] = cond['neg_cond'][:1]
torch.manual_seed(seed)
flow_model = self.models['sparse_structure_flow_model']
reso = flow_model.resolution
noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params, noise)
slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
slat = self.sample_slat(cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)
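    # Usage sketch (illustrative; paths are hypothetical):
    #
    #   views = [Image.open(p) for p in ("v0.png", "v1.png", "v2.png")]
    #   out = pipeline.run_multi_image(views, seed=42, formats=['mesh'])
    #   out['mesh']   # decoded mesh representation(s)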
@torch.no_grad()
def run(
self,
image: Image.Image,
ref_image: Image.Image = None,
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
slat_sampler_params: dict = {},
formats: List[str] = ['mesh'],
preprocess_image: bool = True,
init_mesh: trimesh.Trimesh = None,
coords: torch.Tensor = None,
normalize_init_mesh: bool = False,
init_resolution: int = 62,
init_stride: int = 4
) -> dict:
"""
Run the pipeline.
Args:
image (Image.Image): The image prompt.
num_samples (int): The number of samples to generate.
sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
slat_sampler_params (dict): Additional parameters for the structured latent sampler.
preprocess_image (bool): Whether to preprocess the image.
"""
if preprocess_image:
image = self.preprocess_image(image)
if ref_image is not None:
cond = self.encode_image([image, ref_image])
neg_cond = torch.zeros_like(cond[0:1])
sparse_cond = slat_cond = {
'cond': 0.5 * cond[0:1] + 0.5 * cond[1:2],
'neg_cond': neg_cond,
}
else:
sparse_cond = slat_cond = self.get_cond([image])
torch.manual_seed(seed)
        if coords is None:
            coords = self.sample_sparse_structure(sparse_cond, num_samples, sparse_structure_sampler_params)
slat = self.sample_slat(slat_cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)
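    # Usage sketch (illustrative; the checkpoint name is an assumption):
    #
    #   pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
    #   # preprocess_image=False here because background removal needs the
    #   # BiRefNet model, which only TrellisVGGTTo3DPipeline.from_pretrained loads.
    #   out = pipeline.run(Image.open("input_rgba.png"), preprocess_image=False)
    #   out['mesh']   # decoded mesh representation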
def configure_optimizers(self):
params = list(self.slat_flow_model.parameters())
opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0)
return opt
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
def get_ss_cond(self, image_cond: torch.Tensor, aggregated_tokens_list: List, num_samples: int) -> dict:
"""
Get the conditioning information for the model.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image prompts.
Returns:
dict: The conditioning information
"""
cond = self.sparse_structure_vggt_cond(aggregated_tokens_list, image_cond)
neg_cond = torch.zeros_like(cond)
return {
'cond': cond,
'neg_cond': neg_cond,
}
@torch.no_grad()
def vggt_feat(self, image: Union[torch.Tensor, list[Image.Image]]) -> List:
"""
Encode the image.
Args:
image (Union[torch.Tensor, list[Image.Image]]): The image to encode
Returns:
torch.Tensor: The encoded features.
"""
if isinstance(image, torch.Tensor):
assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False)
elif isinstance(image, list):
assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).to(self.device)
else:
raise ValueError(f"Unsupported type of image: {type(image)}")
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=self.VGGT_dtype):
# Predict attributes including cameras, depth maps, and point maps.
aggregated_tokens_list, _ = self.VGGT_model.aggregator(image[None])
return aggregated_tokens_list, image
def run(
self,
image: Union[torch.Tensor, list[Image.Image]],
coords: torch.Tensor = None,
num_samples: int = 1,
seed: int = 42,
sparse_structure_sampler_params: dict = {},
slat_sampler_params: dict = {},
formats: List[str] = ['mesh'],
preprocess_image: bool = True,
mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
):
torch.manual_seed(seed)
aggregated_tokens_list, _ = self.vggt_feat(image)
b, n, _, _ = aggregated_tokens_list[0].shape
image_cond = self.encode_image(image).reshape(b, n, -1, 1024)
        if coords is None:
            ss_flow_model = self.models['sparse_structure_flow_model']
            # Skip the class/register tokens; condition on patch tokens only
            ss_cond = self.get_ss_cond(image_cond[:, :, 5:], aggregated_tokens_list, num_samples)
            # Sample the sparse structure latent
            ss_sampler_params = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}
            reso = ss_flow_model.resolution
            ss_noise = torch.randn(num_samples, ss_flow_model.in_channels, reso, reso, reso).to(self.device)
            ss_slat = self.sparse_structure_sampler.sample(
                ss_flow_model,
                ss_noise,
                **ss_cond,
                **ss_sampler_params,
                verbose=True
            ).samples
            decoder = self.models['sparse_structure_decoder']
            coords = torch.argwhere(decoder(ss_slat) > 0)[:, [0, 2, 3, 4]].int()
cond = {
'cond': image_cond.reshape(n, -1, 1024),
'neg_cond': torch.zeros_like(image_cond.reshape(n, -1, 1024))[:1],
}
slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode):
slat = self.sample_slat(cond, coords, slat_sampler_params)
return self.decode_slat(slat, formats)
@staticmethod
def from_pretrained(path: str) -> "TrellisVGGTTo3DPipeline":
"""
Load a pretrained model.
Args:
path (str): The path to the model. Can be either local path or a Hugging Face repository.
"""
pipeline = super(TrellisVGGTTo3DPipeline, TrellisVGGTTo3DPipeline).from_pretrained(path)
new_pipeline = TrellisVGGTTo3DPipeline()
new_pipeline.__dict__ = pipeline.__dict__
args = pipeline._pretrained_args
new_pipeline.VGGT_dtype = torch.float32
VGGT_model = VGGT.from_pretrained("Stable-X/vggt-object-v0-1")
new_pipeline.VGGT_model = VGGT_model.to(new_pipeline.device)
del new_pipeline.VGGT_model.depth_head
del new_pipeline.VGGT_model.track_head
del new_pipeline.VGGT_model.camera_head
del new_pipeline.VGGT_model.point_head
new_pipeline.VGGT_model.eval()
new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
'ZhengPeng7/BiRefNet',
trust_remote_code=True
).to(new_pipeline.device)
new_pipeline.birefnet_model.eval()
new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
new_pipeline.slat_sampler_params = args['slat_sampler']['params']
new_pipeline.slat_normalization = args['slat_normalization']
new_pipeline._init_image_cond_model(args['image_cond_model'])
return new_pipeline
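# Minimal end-to-end sketch (illustrative; the pipeline checkpoint path is a
# hypothetical placeholder, and a CUDA device plus network access for the
# VGGT and BiRefNet weights are assumed).
if __name__ == "__main__":
    pipeline = TrellisVGGTTo3DPipeline.from_pretrained("path/or/hf-repo-of-trellis-pipeline")
    views = [pipeline.preprocess_image(Image.open(p)) for p in ("view_0.png", "view_1.png")]
    outputs = pipeline.run(views, seed=42, formats=['mesh'])
    print({k: type(v) for k, v in outputs.items()})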