import os import shutil import sys import warnings import random import time import logging import fal_client import base64 import numpy as np import math import scipy import requests import torch import torchvision import gradio as gr import argparse import spaces from PIL import Image, ImageFilter, ImageOps, ImageDraw, ImageFont from io import BytesIO from typing import Dict, List, Tuple, Union, Optional # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[logging.StreamHandler()] ) logger = logging.getLogger(__name__) # Download model weights only if they don't exist if not os.path.exists("groundingdino_swint_ogc.pth"): os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth") if not os.path.exists("sam_hq_vit_l.pth"): os.system("wget https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth") # Add paths sys.path.append(os.path.join(os.getcwd(), "GroundingDINO")) sys.path.append(os.path.join(os.getcwd(), "sam-hq")) warnings.filterwarnings("ignore") # Grounding DINO import GroundingDINO.groundingdino.datasets.transforms as T from GroundingDINO.groundingdino.models import build_model from GroundingDINO.groundingdino.util.slconfig import SLConfig from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap # segment anything from segment_anything import build_sam_vit_l, SamPredictor # Constants CONFIG_FILE = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py' GROUNDINGDINO_CHECKPOINT = "groundingdino_swint_ogc.pth" SAM_CHECKPOINT = 'sam_hq_vit_l.pth' OUTPUT_DIR = "outputs" # Global variables for model caching _models = { 'groundingdino': None, 'sam_predictor': None } # Enable GPU if available with proper error handling try: device = 'cuda' if torch.cuda.is_available() else 'cpu' logger.info(f"Using device: {device}") except Exception as e: logger.warning(f"Error detecting GPU, falling back to CPU: {e}") device = 'cpu' class ModelManager: """Manages model loading, unloading, and provides error handling""" @staticmethod def load_model(model_name: str) -> None: """Load a model if not already loaded""" try: if model_name == 'groundingdino' and _models['groundingdino'] is None: logger.info("Loading GroundingDINO model...") start_time = time.time() if not os.path.exists(GROUNDINGDINO_CHECKPOINT): raise FileNotFoundError(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}") args = SLConfig.fromfile(CONFIG_FILE) args.device = device model = build_model(args) checkpoint = torch.load(GROUNDINGDINO_CHECKPOINT, map_location="cpu") load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False) logger.info(f"GroundingDINO load result: {load_res}") _ = model.eval() _models['groundingdino'] = model logger.info(f"GroundingDINO model loaded in {time.time() - start_time:.2f} seconds") elif model_name == 'sam' and _models['sam_predictor'] is None: logger.info("Loading SAM-HQ model...") start_time = time.time() if not os.path.exists(SAM_CHECKPOINT): raise FileNotFoundError(f"SAM checkpoint not found at {SAM_CHECKPOINT}") sam = build_sam_vit_l(checkpoint=SAM_CHECKPOINT) sam.to(device=device) _models['sam_predictor'] = SamPredictor(sam) logger.info(f"SAM-HQ model loaded in {time.time() - start_time:.2f} seconds") except Exception as e: logger.error(f"Error loading {model_name} model: {e}") raise RuntimeError(f"Failed to load {model_name} model: {e}") @staticmethod def get_model(model_name: str): """Get a model, loading it if necessary""" if model_name not in _models or _models[model_name] is None: ModelManager.load_model(model_name) return _models[model_name] @staticmethod def unload_model(model_name: str) -> None: """Unload a model to free memory""" if model_name in _models and _models[model_name] is not None: logger.info(f"Unloading {model_name} model") _models[model_name] = None if device == 'cuda': torch.cuda.empty_cache() def transform_image(image_pil: Image.Image) -> torch.Tensor: """Transform PIL image for GroundingDINO""" transform = T.Compose([ T.RandomResize([800], max_size=1333), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) image, _ = transform(image_pil, None) # 3, h, w return image def get_grounding_output( image: torch.Tensor, caption: str, box_threshold: float, text_threshold: float, with_logits: bool = True ) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: """Run GroundingDINO to get bounding boxes from text prompt""" try: model = ModelManager.get_model('groundingdino') # Format caption caption = caption.lower().strip() if not caption.endswith("."): caption = caption + "." with torch.no_grad(): outputs = model(image[None], captions=[caption]) logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256) boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4) # Filter output logits_filt = logits.clone() boxes_filt = boxes.clone() filt_mask = logits_filt.max(dim=1)[0] > box_threshold logits_filt = logits_filt[filt_mask] # num_filt, 256 boxes_filt = boxes_filt[filt_mask] # num_filt, 4 # Get phrases tokenizer = model.tokenizer tokenized = tokenizer(caption) pred_phrases = [] scores = [] for logit, box in zip(logits_filt, boxes_filt): pred_phrase = get_phrases_from_posmap( logit > text_threshold, tokenized, tokenizer) if with_logits: pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})") else: pred_phrases.append(pred_phrase) scores.append(logit.max().item()) return boxes_filt, torch.Tensor(scores), pred_phrases except Exception as e: logger.error(f"Error in grounding output: {e}") # Return empty results instead of crashing return torch.Tensor([]), torch.Tensor([]), [] def draw_mask(mask: np.ndarray, draw: ImageDraw.Draw) -> None: """Draw mask on image""" color = (255, 255, 255, 255) nonzero_coords = np.transpose(np.nonzero(mask)) for coord in nonzero_coords: draw.point(coord[::-1], fill=color) def draw_box(box: torch.Tensor, draw: ImageDraw.Draw, label: Optional[str]) -> None: """Draw bounding box on image""" color = tuple(np.random.randint(0, 255, size=3).tolist()) draw.rectangle(((box[0], box[1]), (box[2], box[3])), outline=color, width=2) if label: font = ImageFont.load_default() if hasattr(font, "getbbox"): bbox = draw.textbbox((box[0], box[1]), str(label), font) else: w, h = draw.textsize(str(label), font) bbox = (box[0], box[1], w + box[0], box[1] + h) draw.rectangle(bbox, fill=color) draw.text((box[0], box[1]), str(label), fill="white") def run_grounded_sam(input_image, product): """Main function to run GroundingDINO and SAM-HQ""" # Create output directory os.makedirs(OUTPUT_DIR, exist_ok=True) text_prompt = product task_type = 'text' box_threshold = 0.3 text_threshold = 0.25 iou_threshold = 0.8 hq_token_only = True # Process input image if isinstance(input_image, dict): # Input from gradio sketch component scribble = np.array(input_image["mask"]) image_pil = input_image["image"].convert("RGB") else: # Direct image input image_pil = input_image.convert("RGB") if input_image else None scribble = None if image_pil is None: logger.error("No input image provided") return [Image.new('RGB', (400, 300), color='gray')] # Transform image for GroundingDINO transformed_image = transform_image(image_pil) # Load models as needed ModelManager.load_model('groundingdino') size = image_pil.size H, W = size[1], size[0] # Run GroundingDINO with provided text boxes_filt, scores, pred_phrases = get_grounding_output( transformed_image, text_prompt, box_threshold, text_threshold ) if boxes_filt is not None: # Scale boxes to image dimensions for i in range(boxes_filt.size(0)): boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H]) boxes_filt[i][:2] -= boxes_filt[i][2:] / 2 boxes_filt[i][2:] += boxes_filt[i][:2] # Apply non-maximum suppression if we have multiple boxes if boxes_filt.size(0) > 1: logger.info(f"Before NMS: {boxes_filt.shape[0]} boxes") nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist() boxes_filt = boxes_filt[nms_idx] pred_phrases = [pred_phrases[idx] for idx in nms_idx] logger.info(f"After NMS: {boxes_filt.shape[0]} boxes") # Load SAM model ModelManager.load_model('sam') sam_predictor = ModelManager.get_model('sam_predictor') # Set image for SAM image = np.array(image_pil) sam_predictor.set_image(image) # Run SAM # Use boxes for these task types if boxes_filt.size(0) == 0: logger.warning("No boxes detected") return [image_pil, Image.new('RGBA', size, color=(0, 0, 0, 0))] transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device) masks, _, _ = sam_predictor.predict_torch( point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False, hq_token_only=hq_token_only, ) # Create mask image mask_image = Image.new('RGBA', size, color=(0, 0, 0, 0)) mask_draw = ImageDraw.Draw(mask_image) # Draw masks for mask in masks: draw_mask(mask[0].cpu().numpy(), mask_draw) # Draw boxes and points on original image image_draw = ImageDraw.Draw(image_pil) for box, label in zip(boxes_filt, pred_phrases): draw_box(box, image_draw, label) return mask_image # except Exception as e: # logger.error(f"Error in run_grounded_sam: {e}") # # Return original image on error # if isinstance(input_image, dict) and "image" in input_image: # return [input_image["image"], Image.new('RGBA', input_image["image"].size, color=(0, 0, 0, 0))] # elif isinstance(input_image, Image.Image): # return [input_image, Image.new('RGBA', input_image.size, color=(0, 0, 0, 0))] # else: # return [Image.new('RGB', (400, 300), color='gray'), Image.new('RGBA', (400, 300), color=(0, 0, 0, 0))] def split_image_with_alpha(image): image = image.convert("RGB") return image def gaussian_blur(image, radius=10): """Apply Gaussian blur to image.""" blurred = image.filter(ImageFilter.GaussianBlur(radius=10)) return blurred def invert_image(image): img_inverted = ImageOps.invert(image) return img_inverted def expand_mask(mask, expand, tapered_corners): # Ensure mask is in grayscale (mode 'L') mask = mask.convert("L") # Convert to NumPy array mask_np = np.array(mask) # Define kernel c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], [c, 1, c]], dtype=np.uint8) # Perform dilation or erosion based on expand value if expand > 0: for _ in range(expand): mask_np = scipy.ndimage.grey_dilation(mask_np, footprint=kernel) elif expand < 0: for _ in range(abs(expand)): mask_np = scipy.ndimage.grey_erosion(mask_np, footprint=kernel) # Convert back to PIL image return Image.fromarray(mask_np, mode="L") def image_blend_by_mask(image_a, image_b, mask, blend_percentage): # Ensure images have the same size and mode image_a = image_a.convert('RGB') image_b = image_b.convert('RGB') mask = mask.convert('L') # Resize images if they don't match if image_a.size != image_b.size: image_b = image_b.resize(image_a.size, Image.LANCZOS) # Ensure mask has the same size if mask.size != image_a.size: mask = mask.resize(image_a.size, Image.LANCZOS) # Invert mask mask = ImageOps.invert(mask) # Mask image masked_img = Image.composite(image_a, image_b, mask) # Blend image blend_mask = Image.new(mode="L", size=image_a.size, color=(round(blend_percentage * 255))) blend_mask = ImageOps.invert(blend_mask) img_result = Image.composite(image_a, masked_img, blend_mask) del image_a, image_b, blend_mask, mask return img_result def blend_images(image_a, image_b, blend_percentage): """Blend img_b over image_a using the normal mode with a blend percentage.""" img_a = image_a.convert("RGBA") img_b = image_b.convert("RGBA") # Blend img_b over img_a using alpha_composite (normal blend mode) out_image = Image.alpha_composite(img_a, img_b) out_image = out_image.convert("RGB") # Create blend mask blend_mask = Image.new("L", image_a.size, round(blend_percentage * 255)) blend_mask = ImageOps.invert(blend_mask) # Invert the mask # Apply composite blend result = Image.composite(image_a, out_image, blend_mask) return result def apply_image_levels(image, black_level, mid_level, white_level): levels = AdjustLevels(black_level, mid_level, white_level) adjusted_image = levels.adjust(image) return adjusted_image class AdjustLevels: def __init__(self, min_level, mid_level, max_level): self.min_level = min_level self.mid_level = mid_level self.max_level = max_level def adjust(self, im): im_arr = np.array(im).astype(np.float32) im_arr[im_arr < self.min_level] = self.min_level im_arr = (im_arr - self.min_level) * \ (255 / (self.max_level - self.min_level)) im_arr = np.clip(im_arr, 0, 255) # mid-level adjustment gamma = math.log(0.5) / math.log((self.mid_level - self.min_level) / (self.max_level - self.min_level)) im_arr = np.power(im_arr / 255, gamma) * 255 im_arr = im_arr.astype(np.uint8) im = Image.fromarray(im_arr) return im def resize_image(image, scaling_factor=1): image = image.resize((int(image.width * scaling_factor), int(image.height * scaling_factor))) return image def upscale_image(image, size): new_image = image.resize((size, size), Image.LANCZOS) return new_image def resize_to_square(image, size=1024): # Load image if a file path is provided if isinstance(image, str): img = Image.open(image).convert("RGBA") else: img = image.convert("RGBA") # If already an Image object # Resize while maintaining aspect ratio img.thumbnail((size, size), Image.LANCZOS) # Create a transparent square canvas square_img = Image.new("RGBA", (size, size), (0, 0, 0, 0)) # Calculate the position to paste the resized image (centered) x_offset = (size - img.width) // 2 y_offset = (size - img.height) // 2 # Extract the alpha channel as a mask mask = img.split()[3] if img.mode == "RGBA" else None # Paste the resized image onto the square canvas with the correct transparency mask square_img.paste(img, (x_offset, y_offset), mask) return square_img def encode_image(image): buffer = BytesIO() image.save(buffer, format="PNG") encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8") return f"data:image/png;base64,{encoded_image}" def generate_ai_bg(input_img, prompt): # input_img = resize_image(input_img, 0.01) hf_input_img = encode_image(input_img) handler = fal_client.submit( "fal-ai/iclight-v2", arguments={ "prompt": prompt, "image_url": hf_input_img }, webhook_url="https://optional.webhook.url/for/results", ) request_id = handler.request_id status = fal_client.status("fal-ai/iclight-v2", request_id, with_logs=True) result = fal_client.result("fal-ai/iclight-v2", request_id) relight_img_path = result['images'][0]['url'] response = requests.get(relight_img_path, stream=True) relight_img = Image.open(BytesIO(response.content)).convert("RGBA") # from gradio_client import Client, handle_file # client = Client("lllyasviel/iclight-v2-vary") # result = client.predict( # input_fg=handle_file(input_img), # bg_source="None", # prompt=prompt, # image_width=256, # image_height=256, # num_samples=1, # seed=12345, # steps=25, # n_prompt="lowres, bad anatomy, bad hands, cropped, worst quality", # cfg=2, # gs=5, # enable_hr_fix=True, # hr_downscale=0.5, # lowres_denoise=0.8, # highres_denoise=0.99, # api_name="/process" # ) # print(result) # relight_img_path = result[0][0]['image'] # relight_img = Image.open(relight_img_path).convert("RGBA") return relight_img def blend_details(input_image, relit_image, masked_image, product, scaling_factor=1): # input_image = resize_image(input_image) # relit_image = resize_image(relit_image) # masked_image = resize_image(masked_image) masked_image_rgb = split_image_with_alpha(masked_image) masked_image_blurred = gaussian_blur(masked_image_rgb, radius=10) grow_mask = expand_mask(masked_image_blurred, -15, True) # grow_mask.save("output/grow_mask.png") # Split images and get RGB channels input_image_rgb = split_image_with_alpha(input_image) input_blurred = gaussian_blur(input_image_rgb, radius=10) input_inverted = invert_image(input_image_rgb) # input_blurred.save("output/input_blurred.png") # input_inverted.save("output/input_inverted.png") # Add blurred and inverted images input_blend_1 = blend_images(input_inverted, input_blurred, blend_percentage=0.5) input_blend_1_inverted = invert_image(input_blend_1) input_blend_2 = blend_images(input_blurred, input_blend_1_inverted, blend_percentage=1.0) # input_blend_2.save("output/input_blend_2.png") # Process relit image relit_image_rgb = split_image_with_alpha(relit_image) relit_blurred = gaussian_blur(relit_image_rgb, radius=10) relit_inverted = invert_image(relit_image_rgb) # relit_blurred.save("output/relit_blurred.png") # relit_inverted.save("output/relit_inverted.png") # Add blurred and inverted relit images relit_blend_1 = blend_images(relit_inverted, relit_blurred, blend_percentage=0.5) relit_blend_1_inverted = invert_image(relit_blend_1) relit_blend_2 = blend_images(relit_blurred, relit_blend_1_inverted, blend_percentage=1.0) # relit_blend_2.save("output/relit_blend_2.png") high_freq_comp = image_blend_by_mask(relit_blend_2, input_blend_2, grow_mask, blend_percentage=1.0) # high_freq_comp.save("output/high_freq_comp.png") comped_image = blend_images(relit_blurred, high_freq_comp, blend_percentage=0.65) # comped_image.save("output/comped_image.png") final_image = apply_image_levels(comped_image, black_level=83, mid_level=128, white_level=172) # final_image.save("output/final_image.png") return final_image @spaces.GPU def generate_image(input_image_path, prompt): # resized_input_img = resize_to_square(input_image_path, 256) # resized_input_img_path = '/tmp/gradio/resized_input_img.png' # resized_input_img.convert("RGBA").save(resized_input_img_path, "PNG") # ai_gen_image = generate_ai_bg(resized_input_img, prompt) # upscaled_ai_image = upscale_image(ai_gen_image, 8192) # upscaled_input_image = upscale_image(resized_input_img, 8192) # mask_input_image = run_grounded_sam(upscaled_input_image) # final_image = blend_details(upscaled_input_image, upscaled_ai_image, mask_input_image) # FAL resized_input_img = resize_to_square(input_image_path, 1024) ai_gen_image = generate_ai_bg(resized_input_img, prompt) mask_input_image = run_grounded_sam(resized_input_img, product) final_image = blend_details(resized_input_img, ai_gen_image, mask_input_image, product) return final_image def create_ui(): """Create Gradio UI for CarViz demo""" with gr.Blocks(title="CarViz Demo") as block: gr.Markdown(""" # CarViz """) with gr.Row(): with gr.Column(): input_image_path = gr.Image(type="filepath", label="image") product = gr.Textbox(label="Product", placeholder="Enter what your product is here...") prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...") run_button = gr.Button(value='Run') with gr.Column(): output_image = gr.Image(label="Generated Image") # Run button run_button.click( fn=generate_image, inputs=[ input_image_path, product, prompt ], outputs=[output_image] ) return block if __name__ == "__main__": parser = argparse.ArgumentParser("Carviz demo", add_help=True) parser.add_argument("--debug", action="store_true", help="using debug mode") parser.add_argument("--share", action="store_true", help="share the app") parser.add_argument('--no-gradio-queue', action="store_true", help="disable gradio queue") parser.add_argument('--port', type=int, default=7860, help="port to run the app") parser.add_argument('--host', type=str, default="0.0.0.0", help="host to run the app") args = parser.parse_args() logger.info(f"Starting CarViz demo with args: {args}") # Check for model files if not os.path.exists(GROUNDINGDINO_CHECKPOINT): logger.warning(f"GroundingDINO checkpoint not found at {GROUNDINGDINO_CHECKPOINT}") if not os.path.exists(SAM_CHECKPOINT): logger.warning(f"SAM-HQ checkpoint not found at {SAM_CHECKPOINT}") # Create app block = create_ui() if not args.no_gradio_queue: block = block.queue() # Launch app try: block.launch( debug=args.debug, share=args.share, show_error=True, server_name=args.host, server_port=args.port ) except Exception as e: logger.error(f"Error launching app: {e}")