Mariam-Elz committed on
Commit
94fd648
·
verified ·
1 Parent(s): 4562a05

Upload run.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. run.py +160 -0
run.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from libs.base_utils import do_resize_content
3
+ from imagedream.ldm.util import (
4
+ instantiate_from_config,
5
+ get_obj_from_str,
6
+ )
7
+ from omegaconf import OmegaConf
8
+ from PIL import Image
9
+ import numpy as np
10
+ from inference import generate3d
11
+ from huggingface_hub import hf_hub_download
12
+ import json
13
+ import argparse
14
+ import shutil
15
+ from model import CRM
16
+ import PIL
17
+ import rembg
18
+ import os
19
+ from pipelines import TwoStagePipeline
20
+
21
# Module-level rembg session, created once at import time and handed to
# remove_background() by preprocess_image().
rembg_session = rembg.new_session()
22
+
23
def expand_to_square(image, bg_color=(0, 0, 0, 0)):
    """Pad ``image`` onto a centred square RGBA canvas (1:1 aspect ratio).

    An image that is already square is returned unchanged (the same object).
    Otherwise a new RGBA image whose side equals the longer edge of the input
    is created, filled with ``bg_color``, and the input is pasted in the
    middle of it.
    """
    width, height = image.size
    if width != height:
        side = max(width, height)
        canvas = Image.new("RGBA", (side, side), bg_color)
        # Floor division keeps the paste offset on whole pixels.
        offset = ((side - width) // 2, (side - height) // 2)
        canvas.paste(image, offset)
        return canvas
    return image
33
+
34
def remove_background(
    image: PIL.Image.Image,
    rembg_session = None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
    """Strip the background from ``image`` with rembg unless it already has one removed.

    An RGBA image whose alpha channel contains any value below 255 is treated
    as already masked: it is composited onto a transparent canvas and returned
    without calling rembg, unless ``force`` is truthy. Every other image is
    passed to ``rembg.remove`` together with ``rembg_session`` and any extra
    ``rembg_kwargs``.
    """
    # getextrema() yields one (min, max) pair per band; band 3 is alpha.
    already_masked = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if already_masked:
        # Tell the user why background removal is being skipped.
        print("alhpa channl not enpty, skip remove background, using alpha channel as mask")
        transparent_canvas = Image.new("RGBA", image.size, (0, 0, 0, 0))
        image = Image.alpha_composite(transparent_canvas, image)
        if not force:
            return image
    return rembg.remove(image, session=rembg_session, **rembg_kwargs)
51
+
52
def do_resize_content(original_image: "PIL.Image.Image", scale_rate) -> "PIL.Image.Image":
    """Rescale the image content while keeping the original canvas size.

    NOTE: this definition rebinds the ``do_resize_content`` name that is also
    imported from ``libs.base_utils`` at the top of the file.

    Args:
        original_image: PIL image to rescale.
        scale_rate: Factor applied to both dimensions. A value equal to 1
            returns ``original_image`` itself, untouched.

    Returns:
        ``original_image`` when ``scale_rate == 1``; otherwise a new RGBA image
        of the original size with the rescaled content pasted at its centre
        over a fully transparent background.
    """
    if scale_rate == 1:
        return original_image

    # Both dimensions use the same factor, so the aspect ratio is preserved
    # (up to integer truncation).
    new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
    resized_image = original_image.resize(new_size)

    # Canvas of the original size; (0, 0, 0, 0) is fully transparent.
    padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0))
    paste_position = (
        (original_image.width - resized_image.width) // 2,
        (original_image.height - resized_image.height) // 2,
    )
    padded_image.paste(resized_image, paste_position)
    return padded_image
66
+
67
def add_background(image, bg_color=(255, 255, 255)):
    """Composite an RGBA ``image`` over a solid ``bg_color`` canvas.

    The image's alpha channel acts as the mask; the result is a new RGBA
    image of the same size.
    """
    solid_canvas = Image.new("RGBA", image.size, bg_color)
    composited = Image.alpha_composite(solid_canvas, image)
    return composited
71
+
72
+
73
def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
    """Prepare an input picture for the pipeline.

    The input is expected to be a PIL image in RGBA; the result is an RGB image.

    Args:
        image: Input PIL image.
        background_choice: ``"Alpha as mask"`` keeps the image's own alpha
            channel as the foreground mask; any other value runs background
            removal through ``remove_background``.
        foreground_ratio: Scale factor passed to ``do_resize_content``.
        backgroud_color: Colour composited behind the foreground by
            ``add_background``.

    Returns:
        The processed image converted to mode ``"RGB"``.
    """
    print(background_choice)
    if background_choice == "Alpha as mask":
        background = Image.new("RGBA", image.size, (0, 0, 0, 0))
        image = Image.alpha_composite(background, image)
    else:
        # remove_background() names this parameter ``force``. A misspelled
        # keyword would be absorbed by its **rembg_kwargs and forwarded to
        # rembg.remove, so removal would never actually be forced.
        image = remove_background(image, rembg_session, force=True)
    image = do_resize_content(image, foreground_ratio)
    image = expand_to_square(image)
    image = add_background(image, backgroud_color)
    return image.convert("RGB")
87
+
88
if __name__ == "__main__":
    # Command-line entry point: preprocess one image, run the two-stage
    # pipeline on it, then build a 3D result with the CRM model. All outputs
    # are written under --outdir.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--inputdir",
        type=str,
        default="examples/kunkun.webp",
        # Despite the flag name, the value is opened directly as an image file.
        help="dir for input image",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=5.0,
    )
    parser.add_argument(
        "--step",
        type=int,
        default=50,
    )
    parser.add_argument(
        "--bg_choice",
        type=str,
        default="Auto Remove background",
        help="[Auto Remove background] or [Alpha as mask]",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        default="out/",
    )
    args = parser.parse_args()

    img = Image.open(args.inputdir)
    img = preprocess_image(img, args.bg_choice, 1.0, (127, 127, 127))
    os.makedirs(args.outdir, exist_ok=True)
    # os.path.join keeps outputs inside --outdir even when the value has no
    # trailing separator (plain concatenation would yield e.g. "outfoo.png").
    img.save(os.path.join(args.outdir, "preprocessed_image.png"))

    # CRM weights are fetched from the Zhengyi/CRM hub repository.
    crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
    # Context manager so the specs file handle is closed deterministically.
    with open("configs/specs_objaverse_total.json") as specs_file:
        specs = json.load(specs_file)
    model = CRM(specs).to("cuda")
    model.load_state_dict(torch.load(crm_path, map_location="cuda"), strict=False)

    stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
    stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config
    stage2_sampler_config = stage2_config.sampler
    stage1_sampler_config = stage1_config.sampler

    stage1_model_config = stage1_config.models
    stage2_model_config = stage2_config.models

    # Point each stage's ``resume`` entry at its downloaded checkpoint:
    # stage 1 uses pixel-diffusion.pth, stage 2 uses ccm-diffusion.pth.
    xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
    pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
    stage1_model_config.resume = pixel_path
    stage2_model_config.resume = xyz_path

    pipeline = TwoStagePipeline(
        stage1_model_config,
        stage2_model_config,
        stage1_sampler_config,
        stage2_sampler_config,
    )

    rt_dict = pipeline(img, scale=args.scale, step=args.step)
    stage1_images = rt_dict["stage1_images"]
    stage2_images = rt_dict["stage2_images"]
    # Join each stage's images along axis 1.
    # NOTE(review): presumably lays them out side by side - confirm the array layout.
    np_imgs = np.concatenate(stage1_images, 1)
    np_xyzs = np.concatenate(stage2_images, 1)
    Image.fromarray(np_imgs).save(os.path.join(args.outdir, "pixel_images.png"))
    Image.fromarray(np_xyzs).save(os.path.join(args.outdir, "xyz_images.png"))

    # glb_path is returned but not used here.
    # NOTE(review): obj_path is copied under a .zip name - presumably
    # generate3d returns an archive path; confirm.
    glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, "cuda")
    shutil.copy(obj_path, os.path.join(args.outdir, "output3d.zip"))