Mariam-Elz committed on
Commit 6f85658 · verified · 1 Parent(s): b7eef6f

Upload pipelines.py with huggingface_hub

Files changed (1)
  pipelines.py +212 -170
pipelines.py CHANGED
@@ -1,170 +1,212 @@
- import torch
- from libs.base_utils import do_resize_content
- from imagedream.ldm.util import (
-     instantiate_from_config,
-     get_obj_from_str,
- )
- from omegaconf import OmegaConf
- from PIL import Image
- import numpy as np
-
-
- class TwoStagePipeline(object):
-     def __init__(
-         self,
-         stage1_model_config,
-         stage2_model_config,
-         stage1_sampler_config,
-         stage2_sampler_config,
-         device="cuda",
-         dtype=torch.float16,
-         resize_rate=1,
-     ) -> None:
-         """
-         Only for the two-stage generation process.
-         - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
-         - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final images, based on the stage2-test config
-         """
-         self.resize_rate = resize_rate
-
-         self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
-         self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
-         self.stage1_model = self.stage1_model.to(device).to(dtype)
-
-         self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
-         sd = torch.load(stage2_model_config.resume, map_location="cpu")
-         self.stage2_model.load_state_dict(sd, strict=False)
-         self.stage2_model = self.stage2_model.to(device).to(dtype)
-
-         self.stage1_model.device = device
-         self.stage2_model.device = device
-         self.device = device
-         self.dtype = dtype
-         self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
-             self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
-         )
-         self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
-             self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
-         )
-
-     def stage1_sample(
-         self,
-         pixel_img,
-         prompt="3D assets",
-         neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
-         step=50,
-         scale=5,
-         ddim_eta=0.0,
-     ):
-         if isinstance(pixel_img, str):
-             pixel_img = Image.open(pixel_img)
-
-         if isinstance(pixel_img, Image.Image):
-             if pixel_img.mode == "RGBA":
-                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
-                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
-             else:
-                 pixel_img = pixel_img.convert("RGB")
-         else:
-             raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
-         uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
-         stage1_images = self.stage1_sampler.i2i(
-             self.stage1_sampler.model,
-             self.stage1_sampler.size,
-             prompt,
-             uc=uc,
-             sampler=self.stage1_sampler.sampler,
-             ip=pixel_img,
-             step=step,
-             scale=scale,
-             batch_size=self.stage1_sampler.batch_size,
-             ddim_eta=ddim_eta,
-             dtype=self.stage1_sampler.dtype,
-             device=self.stage1_sampler.device,
-             camera=self.stage1_sampler.camera,
-             num_frames=self.stage1_sampler.num_frames,
-             pixel_control=(self.stage1_sampler.mode == "pixel"),
-             transform=self.stage1_sampler.image_transform,
-             offset_noise=self.stage1_sampler.offset_noise,
-         )
-
-         stage1_images = [Image.fromarray(img) for img in stage1_images]
-         stage1_images.pop(self.stage1_sampler.ref_position)
-         return stage1_images
-
-     def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
-         if isinstance(pixel_img, str):
-             pixel_img = Image.open(pixel_img)
-
-         if isinstance(pixel_img, Image.Image):
-             if pixel_img.mode == "RGBA":
-                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
-                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
-             else:
-                 pixel_img = pixel_img.convert("RGB")
-         else:
-             raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
-         stage2_images = self.stage2_sampler.i2iStage2(
-             self.stage2_sampler.model,
-             self.stage2_sampler.size,
-             "3D assets",
-             self.stage2_sampler.uc,
-             self.stage2_sampler.sampler,
-             pixel_images=stage1_images,
-             ip=pixel_img,
-             step=step,
-             scale=scale,
-             batch_size=self.stage2_sampler.batch_size,
-             ddim_eta=0.0,
-             dtype=self.stage2_sampler.dtype,
-             device=self.stage2_sampler.device,
-             camera=self.stage2_sampler.camera,
-             num_frames=self.stage2_sampler.num_frames,
-             pixel_control=(self.stage2_sampler.mode == "pixel"),
-             transform=self.stage2_sampler.image_transform,
-             offset_noise=self.stage2_sampler.offset_noise,
-         )
-         stage2_images = [Image.fromarray(img) for img in stage2_images]
-         return stage2_images
-
-     def set_seed(self, seed):
-         self.stage1_sampler.seed = seed
-         self.stage2_sampler.seed = seed
-
-     def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
-         pixel_img = do_resize_content(pixel_img, self.resize_rate)
-         stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
-         stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
-
-         return {
-             "ref_img": pixel_img,
-             "stage1_images": stage1_images,
-             "stage2_images": stage2_images,
-         }
-
-
- if __name__ == "__main__":
-
-     stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
-     stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
-     stage2_sampler_config = stage2_config.sampler
-     stage1_sampler_config = stage1_config.sampler
-
-     stage1_model_config = stage1_config.models
-     stage2_model_config = stage2_config.models
-
-     pipeline = TwoStagePipeline(
-         stage1_model_config,
-         stage2_model_config,
-         stage1_sampler_config,
-         stage2_sampler_config,
-     )
-
-     img = Image.open("assets/astronaut.png")
-     rt_dict = pipeline(img)
-     stage1_images = rt_dict["stage1_images"]
-     stage2_images = rt_dict["stage2_images"]
-     np_imgs = np.concatenate(stage1_images, 1)
-     np_xyzs = np.concatenate(stage2_images, 1)
-     Image.fromarray(np_imgs).save("pixel_images.png")
-     Image.fromarray(np_xyzs).save("xyz_images.png")
+ import torch
+ from libs.base_utils import do_resize_content
+ from imagedream.ldm.util import (
+     instantiate_from_config,
+     get_obj_from_str,
+ )
+ from omegaconf import OmegaConf
+ from PIL import Image
+ import PIL
+ import rembg
+
+
+ class TwoStagePipeline(object):
+     def __init__(
+         self,
+         stage1_model_config,
+         stage2_model_config,
+         stage1_sampler_config,
+         stage2_sampler_config,
+         device="cuda",
+         dtype=torch.float16,
+         resize_rate=1,
+     ) -> None:
+         """
+         Only for the two-stage generation process.
+         - the first stage is conditioned on a single pixel image and generates multi-view pixel images, based on the v2pp config
+         - the second stage is conditioned on the multi-view pixel images generated by the first stage and generates the final images, based on the stage2-test config
+         """
+         self.resize_rate = resize_rate
+
+         self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model)
+         self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False)
+         self.stage1_model = self.stage1_model.to(device).to(dtype)
+
+         self.stage2_model = instantiate_from_config(OmegaConf.load(stage2_model_config.config).model)
+         sd = torch.load(stage2_model_config.resume, map_location="cpu")
+         self.stage2_model.load_state_dict(sd, strict=False)
+         self.stage2_model = self.stage2_model.to(device).to(dtype)
+
+         self.stage1_model.device = device
+         self.stage2_model.device = device
+         self.device = device
+         self.dtype = dtype
+         self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)(
+             self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params
+         )
+         self.stage2_sampler = get_obj_from_str(stage2_sampler_config.target)(
+             self.stage2_model, device=device, dtype=dtype, **stage2_sampler_config.params
+         )
+
+     def stage1_sample(
+         self,
+         pixel_img,
+         prompt="3D assets",
+         neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.",
+         step=50,
+         scale=5,
+         ddim_eta=0.0,
+     ):
+         if isinstance(pixel_img, str):
+             pixel_img = Image.open(pixel_img)
+
+         if isinstance(pixel_img, Image.Image):
+             if pixel_img.mode == "RGBA":
+                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
+                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
+             else:
+                 pixel_img = pixel_img.convert("RGB")
+         else:
+             raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
+         uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device)
+         stage1_images = self.stage1_sampler.i2i(
+             self.stage1_sampler.model,
+             self.stage1_sampler.size,
+             prompt,
+             uc=uc,
+             sampler=self.stage1_sampler.sampler,
+             ip=pixel_img,
+             step=step,
+             scale=scale,
+             batch_size=self.stage1_sampler.batch_size,
+             ddim_eta=ddim_eta,
+             dtype=self.stage1_sampler.dtype,
+             device=self.stage1_sampler.device,
+             camera=self.stage1_sampler.camera,
+             num_frames=self.stage1_sampler.num_frames,
+             pixel_control=(self.stage1_sampler.mode == "pixel"),
+             transform=self.stage1_sampler.image_transform,
+             offset_noise=self.stage1_sampler.offset_noise,
+         )
+
+         stage1_images = [Image.fromarray(img) for img in stage1_images]
+         stage1_images.pop(self.stage1_sampler.ref_position)
+         return stage1_images
+
+     def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
+         if isinstance(pixel_img, str):
+             pixel_img = Image.open(pixel_img)
+
+         if isinstance(pixel_img, Image.Image):
+             if pixel_img.mode == "RGBA":
+                 background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0))
+                 pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB")
+             else:
+                 pixel_img = pixel_img.convert("RGB")
+         else:
+             raise TypeError("pixel_img must be a file path or a PIL.Image.Image")
+         stage2_images = self.stage2_sampler.i2iStage2(
+             self.stage2_sampler.model,
+             self.stage2_sampler.size,
+             "3D assets",
+             self.stage2_sampler.uc,
+             self.stage2_sampler.sampler,
+             pixel_images=stage1_images,
+             ip=pixel_img,
+             step=step,
+             scale=scale,
+             batch_size=self.stage2_sampler.batch_size,
+             ddim_eta=0.0,
+             dtype=self.stage2_sampler.dtype,
+             device=self.stage2_sampler.device,
+             camera=self.stage2_sampler.camera,
+             num_frames=self.stage2_sampler.num_frames,
+             pixel_control=(self.stage2_sampler.mode == "pixel"),
+             transform=self.stage2_sampler.image_transform,
+             offset_noise=self.stage2_sampler.offset_noise,
+         )
+         stage2_images = [Image.fromarray(img) for img in stage2_images]
+         return stage2_images
+
+     def set_seed(self, seed):
+         self.stage1_sampler.seed = seed
+         self.stage2_sampler.seed = seed
+
+     def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
+         pixel_img = do_resize_content(pixel_img, self.resize_rate)
+         stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
+         stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
+
+         return {
+             "ref_img": pixel_img,
+             "stage1_images": stage1_images,
+             "stage2_images": stage2_images,
+         }
+
+
+ rembg_session = rembg.new_session()
+
+
+ def expand_to_square(image, bg_color=(0, 0, 0, 0)):
+     # expand the image to a 1:1 aspect ratio, centered on a transparent canvas
+     width, height = image.size
+     if width == height:
+         return image
+     new_size = (max(width, height), max(width, height))
+     new_image = Image.new("RGBA", new_size, bg_color)
+     paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
+     new_image.paste(image, paste_position)
+     return new_image
+
+
+ def remove_background(
+     image: PIL.Image.Image,
+     rembg_session=None,
+     force: bool = False,
+     **rembg_kwargs,
+ ) -> PIL.Image.Image:
+     do_remove = True
+     if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
+         # the alpha channel is not empty, so keep it as the mask instead of matting again
+         print("alpha channel not empty, skipping background removal; using alpha channel as mask")
+         background = Image.new("RGBA", image.size, (0, 0, 0, 0))
+         image = Image.alpha_composite(background, image)
+         do_remove = False
+     do_remove = do_remove or force
+     if do_remove:
+         image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
+     return image
+
+
+ def do_resize_content(original_image: Image.Image, scale_rate):
+     # resize the image content while retaining the original canvas size
+     # (note: this definition shadows the do_resize_content imported from libs.base_utils)
+     if scale_rate != 1:
+         # calculate the new size after rescaling
+         new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
+         # resize the image while maintaining the aspect ratio
+         resized_image = original_image.resize(new_size)
+         # create a new image with the original size and a transparent background
+         padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
+         paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2)
+         padded_image.paste(resized_image, paste_position)
+         return padded_image
+     else:
+         return original_image
+
+
+ def add_background(image, bg_color=(255, 255, 255)):
+     # given an RGBA image, use the alpha channel as a mask to composite onto a background color
+     background = Image.new("RGBA", image.size, bg_color)
+     return Image.alpha_composite(background, image)
+
+
+ def preprocess_image(image, background_choice, foreground_ratio, background_color):
+     """
+     input image is a PIL image in RGBA; returns an RGB image
+     """
+     print(background_choice)
+     if background_choice == "Alpha as mask":
+         background = Image.new("RGBA", image.size, (0, 0, 0, 0))
+         image = Image.alpha_composite(background, image)
+     else:
+         image = remove_background(image, rembg_session, force=True)
+     image = do_resize_content(image, foreground_ratio)
+     image = expand_to_square(image)
+     image = add_background(image, background_color)
+     return image.convert("RGB")
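
Taken together, the new helpers form the image-preparation path that now precedes stage 1. A minimal sketch of that path, assuming this file is saved as pipelines.py and that rembg is installed; the "Remove background" label is arbitrary here (any value other than "Alpha as mask" takes the rembg branch), and the 0.75 foreground ratio is an illustrative choice:

    from PIL import Image

    from pipelines import preprocess_image

    img = Image.open("assets/astronaut.png").convert("RGBA")
    # any background_choice other than "Alpha as mask" routes through rembg matting
    rgb = preprocess_image(img, "Remove background", 0.75, (255, 255, 255))
    rgb.save("preprocessed.png")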
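This revision also drops the old __main__ demo. Below is a hedged end-to-end sketch mirroring the removed block (config paths, input asset, and output names are copied from it; whether they still match this revision is an assumption), with the new preprocessing step in front:

    import numpy as np
    from omegaconf import OmegaConf
    from PIL import Image

    from pipelines import TwoStagePipeline, preprocess_image

    # config paths copied verbatim from the removed __main__ block
    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config

    pipeline = TwoStagePipeline(
        stage1_config.models,
        stage2_config.models,
        stage1_config.sampler,
        stage2_config.sampler,
    )

    # "Alpha as mask" keeps the existing alpha channel; a foreground ratio of 1.0
    # leaves the content size unchanged
    img = Image.open("assets/astronaut.png").convert("RGBA")
    img = preprocess_image(img, "Alpha as mask", 1.0, (255, 255, 255))

    rt_dict = pipeline(img)
    np_imgs = np.concatenate(rt_dict["stage1_images"], 1)
    np_xyzs = np.concatenate(rt_dict["stage2_images"], 1)
    Image.fromarray(np_imgs).save("pixel_images.png")
    Image.fromarray(np_xyzs).save("xyz_images.png")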