Mariam-Elz commited on
Commit
c5c60bf
·
verified ·
1 Parent(s): d9b9206

Upload pipelines.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipelines.py +170 -0
pipelines.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from libs.base_utils import do_resize_content
3
+ from imagedream.ldm.util import (
4
+ instantiate_from_config,
5
+ get_obj_from_str,
6
+ )
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ import numpy as np
10
+
11
+
12
+ class TwoStagePipeline(object):
13
+ def __init__(
14
+ self,
15
+ stage1_model_config,
16
+ stage2_model_config,
17
+ stage1_sampler_config,
18
+ stage2_sampler_config,
19
+ device="cuda",
20
+ dtype=torch.float16,
21
+ resize_rate=1,
22
+ ) -> None:
23
+ """
24
+ only for two stage generate process.
25
+ - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config
26
+ - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config
27
+ """
28
+ self.resize_rate = resize_rate
29
+
30
+ self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
31
+ self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
32
+ self.stage1_model = self.stage1_model.to(device).to(dtype)
33
+
34
+ self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
35
+ sd = torch.load(stage2_model_config.resume, map_location="cpu")
36
+ self.stage2_model.load_state_dict(sd, strict=False)
37
+ self.stage2_model = self.stage2_model.to(device).to(dtype)
38
+
39
+ self.stage1_model.device = device
40
+ self.stage2_model.device = device
41
+ self.device = device
42
+ self.dtype = dtype
43
+ self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
44
+ self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
45
+ )
46
+ self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
47
+ self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
48
+ )
49
+
50
+ def stage1_sample(
51
+ self,
52
+ pixel_img,
53
+ prompt="3D assets",
54
+ neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
55
+ step=50,
56
+ scale=5,
57
+ ddim_eta=0.0,
58
+ ):
59
+ if type(pixel_img) == str:
60
+ pixel_img = Image.open(pixel_img)
61
+
62
+ if isinstance(pixel_img, Image.Image):
63
+ if pixel_img.mode == "RGBA":
64
+ background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
65
+ pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
66
+ else:
67
+ pixel_img = pixel_img.convert("RGB")
68
+ else:
69
+ raise
70
+ uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
71
+ stage1_images = self.stage1_sampler.i2i(
72
+ self.stage1_sampler.model,
73
+ self.stage1_sampler.size,
74
+ prompt,
75
+ uc=uc,
76
+ sampler=self.stage1_sampler.sampler,
77
+ ip=pixel_img,
78
+ step=step,
79
+ scale=scale,
80
+ batch_size=self.stage1_sampler.batch_size,
81
+ ddim_eta=ddim_eta,
82
+ dtype=self.stage1_sampler.dtype,
83
+ device=self.stage1_sampler.device,
84
+ camera=self.stage1_sampler.camera,
85
+ num_frames=self.stage1_sampler.num_frames,
86
+ pixel_control=(self.stage1_sampler.mode == "pixel"),
87
+ transform=self.stage1_sampler.image_transform,
88
+ offset_noise=self.stage1_sampler.offset_noise,
89
+ )
90
+
91
+ stage1_images = [Image.fromarray(img) for img in stage1_images]
92
+ stage1_images.pop(self.stage1_sampler.ref_position)
93
+ return stage1_images
94
+
95
+ def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
96
+ if type(pixel_img) == str:
97
+ pixel_img = Image.open(pixel_img)
98
+
99
+ if isinstance(pixel_img, Image.Image):
100
+ if pixel_img.mode == "RGBA":
101
+ background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
102
+ pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
103
+ else:
104
+ pixel_img = pixel_img.convert("RGB")
105
+ else:
106
+ raise
107
+ stage2_images = self.stage2_sampler.i2iStage2(
108
+ self.stage2_sampler.model,
109
+ self.stage2_sampler.size,
110
+ "3D assets",
111
+ self.stage2_sampler.uc,
112
+ self.stage2_sampler.sampler,
113
+ pixel_images=stage1_images,
114
+ ip=pixel_img,
115
+ step=step,
116
+ scale=scale,
117
+ batch_size=self.stage2_sampler.batch_size,
118
+ ddim_eta=0.0,
119
+ dtype=self.stage2_sampler.dtype,
120
+ device=self.stage2_sampler.device,
121
+ camera=self.stage2_sampler.camera,
122
+ num_frames=self.stage2_sampler.num_frames,
123
+ pixel_control=(self.stage2_sampler.mode == "pixel"),
124
+ transform=self.stage2_sampler.image_transform,
125
+ offset_noise=self.stage2_sampler.offset_noise,
126
+ )
127
+ stage2_images = [Image.fromarray(img) for img in stage2_images]
128
+ return stage2_images
129
+
130
+ def set_seed(self, seed):
131
+ self.stage1_sampler.seed = seed
132
+ self.stage2_sampler.seed = seed
133
+
134
+ def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
135
+ pixel_img = do_resize_content(pixel_img, self.resize_rate)
136
+ stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
137
+ stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
138
+
139
+ return {
140
+ "ref_img": pixel_img,
141
+ "stage1_images": stage1_images,
142
+ "stage2_images": stage2_images,
143
+ }
144
+
145
+
146
+ if __name__ == "__main__":
147
+
148
+ stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
149
+ stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
150
+ stage2_sampler_config = stage2_config.sampler
151
+ stage1_sampler_config = stage1_config.sampler
152
+
153
+ stage1_model_config = stage1_config.models
154
+ stage2_model_config = stage2_config.models
155
+
156
+ pipeline = TwoStagePipeline(
157
+ stage1_model_config,
158
+ stage2_model_config,
159
+ stage1_sampler_config,
160
+ stage2_sampler_config,
161
+ )
162
+
163
+ img = Image.open("assets/astronaut.png")
164
+ rt_dict = pipeline(img)
165
+ stage1_images = rt_dict["stage1_images"]
166
+ stage2_images = rt_dict["stage2_images"]
167
+ np_imgs = np.concatenate(stage1_images, 1)
168
+ np_xyzs = np.concatenate(stage2_images, 1)
169
+ Image.fromarray(np_imgs).save("pixel_images.png")
170
+ Image.fromarray(np_xyzs).save("xyz_images.png")