import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
import requests
import json

# Install flash-attn without compiling the CUDA kernels from source
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig,
    Bagel,
    Qwen2Config,
    Qwen2ForCausalLM,
    SiglipVisionConfig,
    SiglipVisionModel,
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download

# Get Brave Search API key
BSEARCH_API = os.getenv("BSEARCH_API")

# Download model weights
save_dir = "./model_weights"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)

# Model Initialization
model_path = save_dir

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# Model Loading and Multi-GPU Inference Preparation
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

# Modules that must be placed on the same device as the token embeddings
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed',
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    offload_folder="offload",
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

# Inferencer Preparation
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)


# Brave Search function
def brave_search(query):
    """Perform a web search using the Brave Search API."""
Search API.""" if not BSEARCH_API: return None try: headers = { "Accept": "application/json", "X-Subscription-Token": BSEARCH_API } url = "https://api.search.brave.com/res/v1/web/search" params = { "q": query, "count": 5 } response = requests.get(url, headers=headers, params=params) response.raise_for_status() data = response.json() results = [] if "web" in data and "results" in data["web"]: for idx, result in enumerate(data["web"]["results"][:5], 1): title = result.get("title", "No title") url = result.get("url", "") description = result.get("description", "No description") results.append(f"{idx}. {title}\nURL: {url}\n{description}") if results: return "\n\n".join(results) else: return None except Exception as e: print(f"Search error: {str(e)}") return None def enhance_prompt_with_search(prompt, use_search=False): """Enhance prompt with web search results if enabled.""" if not use_search or not BSEARCH_API: return prompt search_results = brave_search(prompt) if search_results: enhanced_prompt = f"{prompt}\n\n[Web Search Context]:\n{search_results}\n\n[Generate based on the above context and original prompt]" return enhanced_prompt return prompt def set_seed(seed): """Set random seeds for reproducibility""" if seed > 0: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False return seed # Text to Image function with thinking option and hyperparameters @spaces.GPU(duration=90) def text_to_image(prompt, use_web_search=False, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4, timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0, cfg_renorm_type="global", max_think_token_n=1024, do_sample=False, text_temperature=0.3, seed=0, image_ratio="1:1"): # Set seed for reproducibility set_seed(seed) # Enhance prompt with search if enabled enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search) if image_ratio == "1:1": image_shapes = (1024, 1024) elif image_ratio == "4:3": image_shapes = (768, 1024) elif image_ratio == "3:4": image_shapes = (1024, 768) elif image_ratio == "16:9": image_shapes = (576, 1024) elif image_ratio == "9:16": image_shapes = (1024, 576) # Set hyperparameters inference_hyper = dict( max_think_token_n=max_think_token_n if show_thinking else 1024, do_sample=do_sample if show_thinking else False, text_temperature=text_temperature if show_thinking else 0.3, cfg_text_scale=cfg_text_scale, cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0 timestep_shift=timestep_shift, num_timesteps=num_timesteps, cfg_renorm_min=cfg_renorm_min, cfg_renorm_type=cfg_renorm_type, image_shapes=image_shapes, ) result = {"text": "", "image": None} # Call inferencer with or without think parameter based on user choice for i in inferencer(text=enhanced_prompt, think=show_thinking, understanding_output=False, **inference_hyper): if type(i) == str: result["text"] += i else: result["image"] = i yield result["image"], result.get("text", None) # Image Understanding function with thinking option and hyperparameters @spaces.GPU(duration=90) def image_understanding(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False, do_sample=False, text_temperature=0.3, max_new_tokens=512): if image is None: return "Please upload an image." 
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
        text_temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # Caps the number of generated tokens
    )

    result = {"text": "", "image": None}
    # Use show_thinking parameter to control the thinking process
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["text"]


# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False,
               cfg_text_scale=4.0, cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
               cfg_renorm_type="text_channel", max_think_token_n=1024,
               do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    if image is None:
        # Yield (rather than return) the message so it reaches the UI from this generator
        yield None, "Please upload an image."
        return

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        text_temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    # Include thinking parameter based on user choice
    result = {"text": "", "image": None}
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking,
                        understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")


# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None


# Enhanced CSS for visual improvements
custom_css = """
/* Modern gradient background */
.gradio-container {
    background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #3a6fb0 100%);
    min-height: 100vh;
}

/* Main container with glassmorphism */
.container {
    backdrop-filter: blur(10px);
    background: rgba(255, 255, 255, 0.1);
    border-radius: 20px;
    padding: 30px;
    margin: 20px auto;
    max-width: 1400px;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
}

/* Header styling */
h1 {
    background: linear-gradient(90deg, #ffffff 0%, #e0e0e0 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3.5em;
    text-align: center;
    margin-bottom: 30px;
    font-weight: 800;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}

/* Tab styling */
.tabs {
    background: rgba(255, 255, 255, 0.15);
    border-radius: 15px;
    padding: 10px;
    margin-bottom: 20px;
}

.tab-nav {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 10px !important;
    padding: 5px !important;
}

.tab-nav button {
    background: transparent !important;
    color: white !important;
    border: none !important;
    padding: 10px 20px !important;
    margin: 0 5px !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}
.tab-nav button.selected {
    background: rgba(255, 255, 255, 0.3) !important;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2) !important;
}

.tab-nav button:hover {
    background: rgba(255, 255, 255, 0.25) !important;
}

/* Input field styling */
.textbox, .image-container {
    background: rgba(255, 255, 255, 0.95) !important;
    border: 2px solid rgba(255, 255, 255, 0.3) !important;
    border-radius: 12px !important;
    padding: 15px !important;
    color: #333 !important;
    font-size: 16px !important;
    transition: all 0.3s ease !important;
}

.textbox:focus {
    border-color: #3a6fb0 !important;
    box-shadow: 0 0 20px rgba(58, 111, 176, 0.4) !important;
}

/* Button styling */
.primary {
    background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 30px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(76, 175, 80, 0.3) !important;
}

.primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(76, 175, 80, 0.4) !important;
}

/* Checkbox styling */
.checkbox-group {
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px 15px !important;
    border-radius: 8px !important;
    margin: 10px 0 !important;
}

.checkbox-group label {
    color: white !important;
    font-weight: 500 !important;
}

/* Accordion styling */
.accordion {
    background: rgba(255, 255, 255, 0.1) !important;
    border-radius: 12px !important;
    margin: 15px 0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}

.accordion-header {
    background: rgba(255, 255, 255, 0.15) !important;
    color: white !important;
    padding: 12px 20px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
}

/* Slider styling */
.slider {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 5px !important;
}

.slider .handle {
    background: white !important;
    border: 3px solid #3a6fb0 !important;
}

/* Image output styling */
.image-frame {
    border-radius: 15px !important;
    overflow: hidden !important;
    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px !important;
}

/* Footer links */
a {
    color: #64b5f6 !important;
    text-decoration: none !important;
    font-weight: 500 !important;
    transition: color 0.3s ease !important;
}

a:hover {
    color: #90caf9 !important;
}

/* Web search info box */
.web-search-info {
    background: linear-gradient(135deg, rgba(255, 193, 7, 0.2) 0%, rgba(255, 152, 0, 0.2) 100%);
    border: 2px solid rgba(255, 193, 7, 0.5);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: white;
}

.web-search-info h4 {
    margin: 0 0 10px 0;
    color: #ffd54f;
    font-size: 1.2em;
}

.web-search-info p {
    margin: 5px 0;
    font-size: 0.95em;
    line-height: 1.4;
}

/* Loading animation */
.generating {
    border-color: #4CAF50 !important;
    animation: pulse 2s infinite !important;
}

@keyframes pulse {
    0% { box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.7); }
    70% { box-shadow: 0 0 0 10px rgba(76, 175, 80, 0); }
    100% { box-shadow: 0 0 0 0 rgba(76, 175, 80, 0); }
}
"""

# Gradio UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""
Advanced AI Model for Text-to-Image, Image Editing, and Image Understanding