Spaces:
Running
Running
| import os | |
| import cv2 | |
| import time | |
| import tqdm | |
| import numpy as np | |
| import dearpygui.dearpygui as dpg | |
| import torch | |
| import torch.nn.functional as F | |
| import trimesh | |
| import rembg | |
| from cam_utils import orbit_camera, OrbitCamera | |
| from mesh_renderer import Renderer | |
| # from kiui.lpips import LPIPS | |
| class GUI: | |
def __init__(self, opt):
    """Set up camera, renderer and (optionally) the Dear PyGui window.

    Args:
        opt: configuration object. This method reads ``gui``, ``W``, ``H``,
            ``radius``, ``fovy``, ``input`` and ``prompt`` from it and also
            hands it to ``Renderer``. It is stored by reference so that
            parameters can be modified in place later.
    """
    self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
    self.gui = opt.gui  # enable gui
    self.W = opt.W
    self.H = opt.H
    self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

    self.mode = "image"
    self.seed = "random"
    self.last_seed = None  # filled in by seed_everything()

    # Image buffers are laid out (height, width, channels), like the
    # arrays produced by load_input() and test_step().
    self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
    self.need_update = True  # update buffer_image

    # models
    self.device = torch.device("cuda")
    self.bg_remover = None

    self.guidance_sd = None
    self.guidance_zero123 = None

    self.enable_sd = False
    self.enable_zero123 = False

    # renderer
    self.renderer = Renderer(opt).to(self.device)

    # input image
    self.input_img = None
    self.input_mask = None
    self.input_img_torch = None
    self.input_mask_torch = None
    self.input_img_torch_channel_last = None  # filled in by prepare_train()
    self.overlay_input_img = False
    self.overlay_input_img_ratio = 0.5

    # input text
    self.prompt = ""
    self.negative_prompt = ""

    # training stuff
    self.training = False
    self.optimizer = None
    self.fixed_cam = None  # filled in by prepare_train()
    self.step = 0
    self.train_steps = 1  # steps per rendering loop
    # self.lpips_loss = LPIPS(net='vgg').to(self.device)

    # load input data from cmdline
    if self.opt.input is not None:
        self.load_input(self.opt.input)

    # override prompt from cmdline (takes precedence over a caption file
    # picked up by load_input)
    if self.opt.prompt is not None:
        self.prompt = self.opt.prompt

    if self.gui:
        dpg.create_context()
        self.register_dpg()
        self.test_step()
def __del__(self):
    """Tear down the Dear PyGui context when GUI mode was enabled."""
    # The finalizer also runs for objects whose __init__ raised before
    # `self.gui` was assigned, so the attribute may be missing.
    if getattr(self, "gui", False):
        dpg.destroy_context()
def seed_everything(self):
    """Seed hashing, NumPy and PyTorch from ``self.seed``.

    ``self.seed`` is parsed with ``int()``. When that fails (for example
    with the default value ``"random"``) a seed in ``[0, 1000000)`` is drawn
    from NumPy's global generator instead. The seed that was actually used
    is stored in ``self.last_seed``.
    """
    try:
        seed = int(self.seed)
    except (TypeError, ValueError):
        # not a number -> pick one at random
        seed = np.random.randint(0, 1000000)

    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN benchmark mode auto-tunes kernels and may select different
    # algorithms from run to run, which works against the deterministic
    # setting above, so keep it off when seeding for reproducibility.
    torch.backends.cudnn.benchmark = False

    self.last_seed = seed
def prepare_train(self):
    """Reset the step counter and get optimizer, guidance and inputs ready.

    Creates a fresh Adam optimizer over ``self.renderer.get_params()``,
    stores the reference camera in ``self.fixed_cam``, decides which
    guidance models are enabled, loads them on first use, converts the input
    image/mask to tensors at ``opt.ref_size`` and precomputes the guidance
    embeddings.
    """
    self.step = 0

    # optimizer over the renderer's parameters
    self.optimizer = torch.optim.Adam(self.renderer.get_params())

    # reference camera at the configured elevation, azimuth 0
    reference_pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
    self.fixed_cam = (reference_pose, self.cam.perspective)

    # a guidance term is active only if it has a positive weight AND its input
    self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
    self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None

    # guidance models are imported and built lazily, on first use only
    if self.enable_sd and self.guidance_sd is None:
        print(f"[INFO] loading SD...")
        from guidance.sd_utils import StableDiffusion
        self.guidance_sd = StableDiffusion(self.device)
        print(f"[INFO] loaded SD!")

    if self.enable_zero123 and self.guidance_zero123 is None:
        print(f"[INFO] loading zero123...")
        from guidance.zero123_utils import Zero123
        self.guidance_zero123 = Zero123(self.device)
        print(f"[INFO] loaded zero123!")

    # input image / mask -> batched channel-first tensors at ref_size
    if self.input_img is not None:
        ref_hw = (self.opt.ref_size, self.opt.ref_size)

        img = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
        self.input_img_torch = F.interpolate(img, ref_hw, mode="bilinear", align_corners=False)

        mask = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
        self.input_mask_torch = F.interpolate(mask, ref_hw, mode="bilinear", align_corners=False)

        # channel-last copy used by train_step for the known-view loss
        self.input_img_torch_channel_last = self.input_img_torch[0].permute(1, 2, 0).contiguous()

    # embeddings are computed once per training run, without gradients
    with torch.no_grad():
        if self.enable_sd:
            self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
        if self.enable_zero123:
            self.guidance_zero123.get_img_embeds(self.input_img_torch)
def train_step(self):
    """Run ``self.train_steps`` optimisation steps and report timing.

    Each step accumulates up to three loss terms:

    * known view: masked MSE between a render from ``self.fixed_cam`` and the
      input image (only when an input image tensor is available);
    * Stable Diffusion guidance: MSE between random-view renders and the
      output of ``guidance_sd.refine`` on them (when ``enable_sd``);
    * zero123 guidance: the same against ``guidance_zero123.refine``
      (when ``enable_zero123``).

    If none of the terms is active there is nothing to differentiate, so the
    optimizer update is skipped for that step instead of crashing.
    Timing uses CUDA events; in GUI mode the result and the last loss are
    written to the log widgets.
    """
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    loss = 0  # keeps the log line below valid even if train_steps == 0

    for _ in range(self.train_steps):

        self.step += 1

        loss = 0

        ### known view
        if self.input_img_torch is not None:

            # random super-sampling factor in [0.125, 2.0]
            ssaa = min(2.0, max(0.125, 2 * np.random.random()))
            out = self.renderer.render(*self.fixed_cam, self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)

            # rgb loss
            image = out["image"]  # [H, W, 3] in [0, 1]
            # only compare covered pixels whose viewcos exceeds 0.5
            valid_mask = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach()
            loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last * valid_mask)

        ### novel view (manual batch)
        render_resolution = 512
        images = []
        vers, hors, radii = [], [], []
        # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
        min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
        max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
        for _ in range(self.opt.batch_size):

            # render random view (offsets are relative to opt.elevation / opt.radius)
            ver = np.random.randint(min_ver, max_ver)
            hor = np.random.randint(-180, 180)
            radius = 0

            vers.append(ver)
            hors.append(hor)
            radii.append(radius)

            pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)

            # random render resolution
            ssaa = min(2.0, max(0.125, 2 * np.random.random()))
            out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)

            image = out["image"]  # [H, W, 3] in [0, 1]
            image = image.permute(2, 0, 1).contiguous().unsqueeze(0)  # [1, 3, H, W] in [0, 1]

            images.append(image)

        images = torch.cat(images, dim=0)

        # guidance loss: pull the renders toward their refined versions
        if self.enable_sd:
            refined_images = self.guidance_sd.refine(images, strength=0.6).float()
            refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
            loss = loss + self.opt.lambda_sd * F.mse_loss(images, refined_images)

        if self.enable_zero123:
            refined_images = self.guidance_zero123.refine(images, vers, hors, radii, strength=0.6).float()
            refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
            loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)

        # optimize step
        # `loss` is still the plain int 0 when no term above was active;
        # an int has no .backward(), so skip the update in that case.
        if isinstance(loss, torch.Tensor):
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

    ender.record()
    torch.cuda.synchronize()
    t = starter.elapsed_time(ender)

    self.need_update = True

    if self.gui:
        loss_value = loss.item() if isinstance(loss, torch.Tensor) else float(loss)
        dpg.set_value("_log_train_time", f"{t:.4f}ms")
        dpg.set_value(
            "_log_train_log",
            f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss_value:.4f}",
        )
def test_step(self):
    """Re-render the current camera view into ``self.buffer_image``.

    Does nothing unless ``self.need_update`` is set. The selected output
    (``self.mode``) is rendered, converted to a CPU array clamped to
    ``[0, 1]``, optionally blended with the input image and, in GUI mode,
    pushed to the ``_texture`` item together with the measured render time.
    """
    # ignore if no need to update
    if not self.need_update:
        return

    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    starter.record()

    # render image
    out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)

    buffer_image = out[self.mode]  # [H, W, 3]

    if self.mode in ("depth", "alpha"):
        # expand the last dimension threefold so it can be shown as RGB
        buffer_image = buffer_image.repeat(1, 1, 3)
        if self.mode == "depth":
            # min-max normalise; the epsilon avoids 0/0 on a constant map
            buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)

    self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()

    # display input_image
    if self.overlay_input_img and self.input_img is not None:
        self.buffer_image = (
            self.buffer_image * (1 - self.overlay_input_img_ratio)
            + self.input_img * self.overlay_input_img_ratio
        )

    self.need_update = False

    ender.record()
    torch.cuda.synchronize()
    t = starter.elapsed_time(ender)

    if self.gui:
        # guard the FPS computation against a measured time of zero
        fps = int(1000 / t) if t > 0 else 0
        dpg.set_value("_log_infer_time", f"{t:.4f}ms ({fps} FPS)")
        dpg.set_value(
            "_texture", self.buffer_image
        )  # buffer must be contiguous, else seg fault!
def load_input(self, file):
    """Load the reference image (and an optional caption) from ``file``.

    The image is read with all channels preserved. A 3-channel image has its
    background removed with ``rembg`` to obtain an alpha channel. The result
    is resized to ``(self.W, self.H)``, scaled to ``[0, 1]``, composited on
    a white background and stored as RGB in ``self.input_img``; the alpha
    channel is stored in ``self.input_mask``.

    If ``file`` contains ``"_rgba.png"`` and the matching ``"_caption.txt"``
    file exists, its stripped content becomes ``self.prompt``.

    Args:
        file: path of the image to load.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``file``.
        ValueError: if the image is not a 3- or 4-channel image.
    """
    # load image
    print(f'[INFO] load image from {file}...')
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise FileNotFoundError(f"Cannot read image from {file}")
    if img.ndim != 3 or img.shape[-1] not in (3, 4):
        raise ValueError(
            f"Expected a 3- or 4-channel image, got array of shape {img.shape} from {file}"
        )

    if img.shape[-1] == 3:
        # no alpha channel yet: estimate one by removing the background
        if self.bg_remover is None:
            self.bg_remover = rembg.new_session()
        img = rembg.remove(img, session=self.bg_remover)

    img = cv2.resize(
        img, (self.W, self.H), interpolation=cv2.INTER_AREA
    )
    img = img.astype(np.float32) / 255.0

    self.input_mask = img[..., 3:]
    # white bg
    self.input_img = img[..., :3] * self.input_mask + (
        1 - self.input_mask
    )
    # bgr to rgb
    self.input_img = self.input_img[..., ::-1].copy()

    # load prompt
    file_prompt = file.replace("_rgba.png", "_caption.txt")
    # When `file` does not contain "_rgba.png" the replace is a no-op and
    # `file_prompt` is the image itself; never read that as a caption.
    if file_prompt != file and os.path.exists(file_prompt):
        print(f'[INFO] load prompt from {file_prompt}...')
        with open(file_prompt, "r") as f:
            self.prompt = f.read().strip()
def save_model(self):
    """Export the renderer's mesh to ``<opt.outdir>/<opt.save_path>.obj``.

    The output directory is created if it does not exist yet.
    """
    out_dir = self.opt.outdir
    os.makedirs(out_dir, exist_ok=True)

    mesh_file = self.opt.save_path + '.obj'
    path = os.path.join(out_dir, mesh_file)

    self.renderer.export_mesh(path)
    print(f"[INFO] save model to {path}.")
def register_dpg(self):
    """Create all Dear PyGui items: texture, windows, widgets and handlers.

    Builds the render window (showing the ``_texture`` raw texture), the
    control window (seed / input / overlay / prompt / save / train /
    rendering options), the mouse handlers that drive ``self.cam``, the
    viewport and the themes, then shows the viewport. Widget callbacks
    mutate attributes of this object and set ``self.need_update`` where a
    re-render is required.
    """
    ### register texture

    with dpg.texture_registry(show=False):
        dpg.add_raw_texture(
            self.W,
            self.H,
            self.buffer_image,
            format=dpg.mvFormat_Float_rgb,
            tag="_texture",
        )

    ### register window

    # the rendered image, as the primary window
    with dpg.window(
        tag="_primary_window",
        width=self.W,
        height=self.H,
        pos=[0, 0],
        no_move=True,
        no_title_bar=True,
        no_scrollbar=True,
    ):
        # add the texture
        dpg.add_image("_texture")

    # dpg.set_primary_window("_primary_window", True)

    # control window
    with dpg.window(
        label="Control",
        tag="_control_window",
        width=600,
        height=self.H,
        pos=[self.W, 0],
        no_move=True,
        no_title_bar=True,
    ):
        # button theme
        with dpg.theme() as theme_button:
            with dpg.theme_component(dpg.mvButton):
                dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

        # timer stuff
        with dpg.group(horizontal=True):
            dpg.add_text("Infer time: ")
            dpg.add_text("no data", tag="_log_infer_time")

        def callback_setattr(sender, app_data, user_data):
            # generic setter: `user_data` names the attribute on self
            setattr(self, user_data, app_data)

        # init stuff
        with dpg.collapsing_header(label="Initialize", default_open=True):

            # seed stuff
            def callback_set_seed(sender, app_data):
                self.seed = app_data
                self.seed_everything()

            dpg.add_input_text(
                label="seed",
                default_value=self.seed,
                on_enter=True,
                callback=callback_set_seed,
            )

            # input stuff
            def callback_select_input(sender, app_data):
                # only one item
                for k, v in app_data["selections"].items():
                    dpg.set_value("_log_input", k)
                    self.load_input(v)

                self.need_update = True

            with dpg.file_dialog(
                directory_selector=False,
                show=False,
                callback=callback_select_input,
                file_count=1,
                tag="file_dialog_tag",
                width=700,
                height=400,
            ):
                dpg.add_file_extension("Images{.jpg,.jpeg,.png}")

            with dpg.group(horizontal=True):
                dpg.add_button(
                    label="input",
                    callback=lambda: dpg.show_item("file_dialog_tag"),
                )
                dpg.add_text("", tag="_log_input")

            # overlay stuff
            with dpg.group(horizontal=True):

                def callback_toggle_overlay_input_img(sender, app_data):
                    self.overlay_input_img = not self.overlay_input_img
                    self.need_update = True

                dpg.add_checkbox(
                    label="overlay image",
                    default_value=self.overlay_input_img,
                    callback=callback_toggle_overlay_input_img,
                )

                def callback_set_overlay_input_img_ratio(sender, app_data):
                    self.overlay_input_img_ratio = app_data
                    self.need_update = True

                dpg.add_slider_float(
                    label="ratio",
                    min_value=0,
                    max_value=1,
                    format="%.1f",
                    default_value=self.overlay_input_img_ratio,
                    callback=callback_set_overlay_input_img_ratio,
                )

            # prompt stuff
            dpg.add_input_text(
                label="prompt",
                default_value=self.prompt,
                callback=callback_setattr,
                user_data="prompt",
            )

            dpg.add_input_text(
                label="negative",
                default_value=self.negative_prompt,
                callback=callback_setattr,
                user_data="negative_prompt",
            )

            # save current model
            with dpg.group(horizontal=True):
                dpg.add_text("Save: ")

                dpg.add_button(
                    label="model",
                    tag="_button_save_model",
                    callback=self.save_model,
                )
                dpg.bind_item_theme("_button_save_model", theme_button)

                def callback_set_save_path(sender, app_data):
                    # save_model() reads the name from self.opt, so the edit
                    # has to land there (setting self.save_path had no effect)
                    self.opt.save_path = app_data

                dpg.add_input_text(
                    label="",
                    default_value=self.opt.save_path,
                    callback=callback_set_save_path,
                )

        # training stuff
        with dpg.collapsing_header(label="Train", default_open=True):
            # lr and train button
            with dpg.group(horizontal=True):
                dpg.add_text("Train: ")

                def callback_train(sender, app_data):
                    # toggle: stop if running, otherwise (re)initialise and start
                    if self.training:
                        self.training = False
                        dpg.configure_item("_button_train", label="start")
                    else:
                        self.prepare_train()
                        self.training = True
                        dpg.configure_item("_button_train", label="stop")

                dpg.add_button(
                    label="start", tag="_button_train", callback=callback_train
                )
                dpg.bind_item_theme("_button_train", theme_button)

            with dpg.group(horizontal=True):
                dpg.add_text("", tag="_log_train_time")
                dpg.add_text("", tag="_log_train_log")

        # rendering options
        with dpg.collapsing_header(label="Rendering", default_open=True):
            # mode combo
            def callback_change_mode(sender, app_data):
                self.mode = app_data
                self.need_update = True

            dpg.add_combo(
                ("image", "depth", "alpha", "normal"),
                label="mode",
                default_value=self.mode,
                callback=callback_change_mode,
            )

            # fov slider (widget works in degrees, camera stores radians)
            def callback_set_fovy(sender, app_data):
                self.cam.fovy = np.deg2rad(app_data)
                self.need_update = True

            dpg.add_slider_int(
                label="FoV (vertical)",
                min_value=1,
                max_value=120,
                format="%d deg",
                # an integer slider takes an integer default
                default_value=int(round(np.rad2deg(self.cam.fovy))),
                callback=callback_set_fovy,
            )

    ### register camera handler
    # every handler ignores events unless the render window has focus

    def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
        if not dpg.is_item_focused("_primary_window"):
            return

        dx = app_data[1]
        dy = app_data[2]

        self.cam.orbit(dx, dy)
        self.need_update = True

    def callback_camera_wheel_scale(sender, app_data):
        if not dpg.is_item_focused("_primary_window"):
            return

        delta = app_data

        self.cam.scale(delta)
        self.need_update = True

    def callback_camera_drag_pan(sender, app_data):
        if not dpg.is_item_focused("_primary_window"):
            return

        dx = app_data[1]
        dy = app_data[2]

        self.cam.pan(dx, dy)
        self.need_update = True

    with dpg.handler_registry():
        # for camera moving
        dpg.add_mouse_drag_handler(
            button=dpg.mvMouseButton_Left,
            callback=callback_camera_drag_rotate_or_draw_mask,
        )
        dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
        dpg.add_mouse_drag_handler(
            button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
        )

    dpg.create_viewport(
        title="Gaussian3D",
        width=self.W + 600,
        height=self.H + (45 if os.name == "nt" else 0),
        resizable=False,
    )

    ### global theme
    with dpg.theme() as theme_no_padding:
        with dpg.theme_component(dpg.mvAll):
            # set all padding to 0 to avoid scroll bar
            dpg.add_theme_style(
                dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
            )
            dpg.add_theme_style(
                dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
            )
            dpg.add_theme_style(
                dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
            )

    dpg.bind_item_theme("_primary_window", theme_no_padding)

    dpg.setup_dearpygui()

    ### register a larger font
    # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
    if os.path.exists("LXGWWenKai-Regular.ttf"):
        with dpg.font_registry():
            with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                dpg.bind_font(default_font)

    # dpg.show_metrics()

    dpg.show_viewport()
def render(self):
    """Drive the GUI: train (if enabled), refresh the view, draw a frame.

    Loops until Dear PyGui reports that it is no longer running. Must only
    be called in GUI mode.
    """
    assert self.gui

    while True:
        if not dpg.is_dearpygui_running():
            break

        # one training call per frame while training is switched on
        if self.training:
            self.train_step()

        # update texture every frame (test_step itself skips unchanged views)
        self.test_step()
        dpg.render_dearpygui_frame()
| # no gui mode | |
def train(self, iters=500):
    """Headless entry point: optimise for ``iters`` steps, then save.

    Args:
        iters: number of ``train_step`` calls; when it is not positive the
            preparation and training are skipped and only the model is saved.
    """
    if iters > 0:
        self.prepare_train()
        progress = tqdm.trange(iters)
        for _ in progress:
            self.train_step()

    # save
    self.save_model()
if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    # only --config is parsed here; every other argument is passed on to
    # OmegaConf as a config override
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = cli.parse_known_args()

    # override default config from cli
    file_conf = OmegaConf.load(args.config)
    cli_conf = OmegaConf.from_cli(extras)
    opt = OmegaConf.merge(file_conf, cli_conf)

    # auto find mesh from stage 1
    if opt.mesh is None:
        default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.obj')
        if not os.path.exists(default_path):
            raise ValueError(f"Cannot find mesh from {default_path}, must specify --mesh explicitly!")
        opt.mesh = default_path

    gui = GUI(opt)

    # interactive window, or headless refinement followed by a save
    if opt.gui:
        gui.render()
    else:
        gui.train(opt.iters_refine)