PixCell
Collection
PixCell models. More info at https://histodiffusion.github.io/docs/projects/pixcell/
•
7 items
•
Updated
[📄 arXiv] [🔬 PixCell-1024] [🔬 PixCell-256] [🔬 PixCell-256-Cell-ControlNet] [💾 Synthetic SBU-1M]
import torch
from diffusers import DiffusionPipeline
from diffusers import AutoencoderKL
# All models in this example are placed on the CUDA device.
device = torch.device("cuda")

# The SD3 VAE weights are not hosted with PixCell, so they are pulled from
# the StabilityAI repository instead.
sd3_vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="vae",
)

# Assemble the PixCell-256 pipeline using the custom pipeline code from the
# StonyBrook-CVLab repo (hence trust_remote_code) and the externally loaded VAE.
pipeline = DiffusionPipeline.from_pretrained(
    "StonyBrook-CVLab/PixCell-256",
    custom_pipeline="StonyBrook-CVLab/PixCell-pipeline",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    vae=sd3_vae,
)
pipeline.to(device)
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
# Architecture overrides forwarded to timm when building the UNI2-h encoder.
timm_kwargs = dict(
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667 * 2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)

# Fetch the pretrained UNI2-h weights from the Hugging Face hub.
uni_model = timm.create_model(
    "hf-hub:MahmoodLab/UNI2-h",
    pretrained=True,
    **timm_kwargs,
)

# Preprocessing transform built from the model's own pretrained data config.
transform = create_transform(
    **resolve_data_config(uni_model.pretrained_cfg, model=uni_model)
)

# Switch to eval mode, then move the encoder onto the target device.
uni_model.eval()
uni_model.to(device)
# Unconditional generation: the pipeline's own unconditional embedding is
# passed as the conditioning input and no negative embedding is supplied.
# The argument 1 is presumably the batch size (the conditional example passes
# uni_emb.shape[0] in the same position) -- confirm against the pipeline code.
uncond = pipeline.get_unconditional_embedding(1)

# The call must sit inside the autocast block, so it has to be indented.
with torch.amp.autocast("cuda"):
    # NOTE(review): `samples` holds the pipeline's return value as-is here;
    # the conditional example reads `.images` from it -- confirm which is wanted.
    samples = pipeline(
        uni_embeds=uncond,
        negative_uni_embeds=None,
        guidance_scale=1.0,
    )
# Image-conditioned generation: embed a reference image with UNI, then sample
# new images conditioned on that embedding using classifier-free guidance.

# Load image
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

# Download the example image shipped with the PixCell-256 repository.
path = hf_hub_download(repo_id="StonyBrook-CVLab/PixCell-256", filename="test_image.png")
image = Image.open(path).convert("RGB")

# Extract the UNI embedding from the image; a leading batch dimension is added
# before the model call.
uni_inp = transform(image).unsqueeze(dim=0)
with torch.inference_mode():
    # Only the forward pass runs under inference mode, so it must be indented.
    uni_emb = uni_model(uni_inp.to(device))

# Reshape UNI to (bs, 1, D)
uni_emb = uni_emb.unsqueeze(1)
print("Extracted UNI:", uni_emb.shape)

# Get the unconditional embedding for classifier-free guidance, one per
# conditioning embedding.
uncond = pipeline.get_unconditional_embedding(uni_emb.shape[0])

# Generate new samples; the pipeline call belongs inside the autocast block.
with torch.amp.autocast("cuda"):
    samples = pipeline(
        uni_embeds=uni_emb,
        negative_uni_embeds=uncond,
        guidance_scale=3.0,
        num_images_per_prompt=1,
    ).images