import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
import requests
import json

# Install flash-attn at startup (Spaces-style); merge os.environ so pip and
# CUDA tooling stay on PATH while skipping the CUDA build.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel,
)
from modeling.qwen2 import Qwen2Tokenizer
from huggingface_hub import snapshot_download

# Get Brave Search API key
BSEARCH_API = os.getenv("BSEARCH_API")

# Download model weights
save_dir = "./model_weights"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(
    cache_dir=cache_dir,
    local_dir=save_dir,
    repo_id=repo_id,
    local_dir_use_symlinks=False,
    resume_download=True,
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)

# Model Initialization
model_path = save_dir

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# Model loading and multi-GPU inference preparation
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

# Modules that must live on the same device as the language model embeddings
same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    offload_folder="offload",
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()

# Inferencer preparation
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)
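# Hedged sketch (kept as comments, not executed): the inferencer is used below as a
# generator that yields strings (thinking/answer text) and PIL images. A minimal
# smoke test mirroring the image-understanding handler further down might look like:
#
#   img = Image.open("test_images/meme.jpg")          # example path from this demo
#   for chunk in inferencer(image=pil_img2rgb(img), text="What is in this image?",
#                           think=False, understanding_output=True,
#                           do_sample=False, text_temperature=0.3,
#                           max_think_token_n=512):
#       if isinstance(chunk, str):
#           print(chunk, end="")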
# Brave Search function
def brave_search(query):
    """Perform a web search using the Brave Search API."""
    if not BSEARCH_API:
        return None

    try:
        headers = {
            "Accept": "application/json",
            "X-Subscription-Token": BSEARCH_API
        }
        url = "https://api.search.brave.com/res/v1/web/search"
        params = {
            "q": query,
            "count": 5
        }

        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()
        data = response.json()

        results = []
        if "web" in data and "results" in data["web"]:
            for idx, result in enumerate(data["web"]["results"][:5], 1):
                title = result.get("title", "No title")
                url = result.get("url", "")
                description = result.get("description", "No description")
                results.append(f"{idx}. {title}\nURL: {url}\n{description}")

        if results:
            return "\n\n".join(results)
        else:
            return None
    except Exception as e:
        print(f"Search error: {str(e)}")
        return None


def enhance_prompt_with_search(prompt, use_search=False):
    """Enhance prompt with web search results if enabled."""
    if not use_search or not BSEARCH_API:
        return prompt

    search_results = brave_search(prompt)
    if search_results:
        enhanced_prompt = (
            f"{prompt}\n\n[Web Search Context]:\n{search_results}\n\n"
            f"[Generate based on the above context and original prompt]"
        )
        return enhanced_prompt
    return prompt


def set_seed(seed):
    """Set random seeds for reproducibility."""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    return seed


# Text-to-Image function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def text_to_image(prompt, use_web_search=False, show_thinking=False,
                  cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    # Set seed for reproducibility
    set_seed(seed)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # The longer side is fixed to 1024; shapes are (height, width)
    if image_ratio == "1:1":
        image_shapes = (1024, 1024)
    elif image_ratio == "4:3":
        image_shapes = (768, 1024)
    elif image_ratio == "3:4":
        image_shapes = (1024, 768)
    elif image_ratio == "16:9":
        image_shapes = (576, 1024)
    elif image_ratio == "9:16":
        image_shapes = (1024, 576)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        text_temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    result = {"text": "", "image": None}
    # Call inferencer with or without the think parameter based on user choice
    for i in inferencer(text=enhanced_prompt, think=show_thinking,
                        understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", None)
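# Hedged usage sketch (kept as comments, not executed): text_to_image is a generator
# that streams (image, text) tuples; the last yielded image is the finished sample.
# Outside the Gradio UI it could be consumed roughly like this (prompt and file name
# are illustrative only):
#
#   final_image, thinking = None, None
#   for img, txt in text_to_image("a watercolor bagel on a wooden table",
#                                 seed=42, image_ratio="1:1"):
#       final_image, thinking = img, txt
#   if final_image is not None:
#       final_image.save("t2i_result.png")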
# Image Understanding function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def image_understanding(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    if image is None:
        return "Please upload an image."

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
        text_temperature=text_temperature,
        max_think_token_n=max_new_tokens,  # Set max_length
    )

    result = {"text": "", "image": None}
    # Use the show_thinking parameter to control the thinking process
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["text"]


# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, use_web_search=False, show_thinking=False,
               cfg_text_scale=4.0, cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50,
               cfg_renorm_min=1.0, cfg_renorm_type="text_channel",
               max_think_token_n=1024, do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    if image is None:
        return "Please upload an image.", ""

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Enhance prompt with search if enabled
    enhanced_prompt = enhance_prompt_with_search(prompt, use_web_search)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
        text_temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    # Include the thinking parameter based on user choice
    result = {"text": "", "image": None}
    for i in inferencer(image=image, text=enhanced_prompt, think=show_thinking,
                        understanding_output=False, **inference_hyper):
        if isinstance(i, str):
            result["text"] += i
        else:
            result["image"] = i
        yield result["image"], result.get("text", "")


# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None


# Enhanced CSS for visual improvements
custom_css = """
/* Modern gradient background */
.gradio-container {
    background: linear-gradient(135deg, #1e3c72 0%, #2a5298 50%, #3a6fb0 100%);
    min-height: 100vh;
}

/* Main container with glassmorphism */
.container {
    backdrop-filter: blur(10px);
    background: rgba(255, 255, 255, 0.1);
    border-radius: 20px;
    padding: 30px;
    margin: 20px auto;
    max-width: 1400px;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.2);
}

/* Header styling */
h1 {
    background: linear-gradient(90deg, #ffffff 0%, #e0e0e0 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3.5em;
    text-align: center;
    margin-bottom: 30px;
    font-weight: 800;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.3);
}

/* Tab styling */
.tabs {
    background: rgba(255, 255, 255, 0.15);
    border-radius: 15px;
    padding: 10px;
    margin-bottom: 20px;
}

.tab-nav {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 10px !important;
    padding: 5px !important;
}

.tab-nav button {
    background: transparent !important;
    color: white !important;
    border: none !important;
    padding: 10px 20px !important;
    margin: 0 5px !important;
    border-radius: 8px !important;
    font-weight: 600 !important;
    transition: all 0.3s ease !important;
}
.tab-nav button.selected {
    background: rgba(255, 255, 255, 0.3) !important;
    box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2) !important;
}

.tab-nav button:hover {
    background: rgba(255, 255, 255, 0.25) !important;
}

/* Input field styling */
.textbox, .image-container {
    background: rgba(255, 255, 255, 0.95) !important;
    border: 2px solid rgba(255, 255, 255, 0.3) !important;
    border-radius: 12px !important;
    padding: 15px !important;
    color: #333 !important;
    font-size: 16px !important;
    transition: all 0.3s ease !important;
}

.textbox:focus {
    border-color: #3a6fb0 !important;
    box-shadow: 0 0 20px rgba(58, 111, 176, 0.4) !important;
}

/* Button styling */
.primary {
    background: linear-gradient(135deg, #4CAF50 0%, #45a049 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 30px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
    font-size: 16px !important;
    cursor: pointer !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(76, 175, 80, 0.3) !important;
}

.primary:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 6px 20px rgba(76, 175, 80, 0.4) !important;
}

/* Checkbox styling */
.checkbox-group {
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px 15px !important;
    border-radius: 8px !important;
    margin: 10px 0 !important;
}

.checkbox-group label {
    color: white !important;
    font-weight: 500 !important;
}

/* Accordion styling */
.accordion {
    background: rgba(255, 255, 255, 0.1) !important;
    border-radius: 12px !important;
    margin: 15px 0 !important;
    border: 1px solid rgba(255, 255, 255, 0.2) !important;
}

.accordion-header {
    background: rgba(255, 255, 255, 0.15) !important;
    color: white !important;
    padding: 12px 20px !important;
    border-radius: 10px !important;
    font-weight: 600 !important;
}

/* Slider styling */
.slider {
    background: rgba(255, 255, 255, 0.2) !important;
    border-radius: 5px !important;
}

.slider .handle {
    background: white !important;
    border: 3px solid #3a6fb0 !important;
}

/* Image output styling */
.image-frame {
    border-radius: 15px !important;
    overflow: hidden !important;
    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
    background: rgba(255, 255, 255, 0.1) !important;
    padding: 10px !important;
}

/* Footer links */
a {
    color: #64b5f6 !important;
    text-decoration: none !important;
    font-weight: 500 !important;
    transition: color 0.3s ease !important;
}

a:hover {
    color: #90caf9 !important;
}

/* Web search info box */
.web-search-info {
    background: linear-gradient(135deg, rgba(255, 193, 7, 0.2) 0%, rgba(255, 152, 0, 0.2) 100%);
    border: 2px solid rgba(255, 193, 7, 0.5);
    border-radius: 10px;
    padding: 15px;
    margin: 10px 0;
    color: white;
}

.web-search-info h4 {
    margin: 0 0 10px 0;
    color: #ffd54f;
    font-size: 1.2em;
}

.web-search-info p {
    margin: 5px 0;
    font-size: 0.95em;
    line-height: 1.4;
}

/* Loading animation */
.generating {
    border-color: #4CAF50 !important;
    animation: pulse 2s infinite !important;
}

@keyframes pulse {
    0% { box-shadow: 0 0 0 0 rgba(76, 175, 80, 0.7); }
    70% { box-shadow: 0 0 0 10px rgba(76, 175, 80, 0); }
    100% { box-shadow: 0 0 0 0 rgba(76, 175, 80, 0); }
}
"""

# Gradio UI
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""

    <div class="container" style="text-align: center;">
        <h1>πŸ₯― BAGEL - Bootstrapping Aligned Generation with Exponential Learning</h1>
        <p style="color: white; font-size: 1.2em;">
            Advanced AI Model for Text-to-Image, Image Editing, and Image Understanding
        </p>
    </div>
    """)

    with gr.Tab("πŸ“ Text to Image"):
        txt_input = gr.Textbox(
            label="Prompt",
            value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere.",
            lines=3
        )

        with gr.Row():
            use_web_search = gr.Checkbox(
                label="πŸ” Enable Web Search",
                value=False,
                info="Search the web for current information to enhance your prompt"
            )
            show_thinking = gr.Checkbox(label="πŸ’­ Show Thinking Process", value=False)

        # Web Search Information Box
        web_search_info = gr.HTML(""" """, visible=False)

        # Show/hide the web search info based on the checkbox
        def toggle_search_info(use_search):
            return gr.update(visible=use_search)

        use_web_search.change(toggle_search_info, inputs=[use_web_search], outputs=[web_search_info])

        # Add hyperparameter controls in an accordion
        with gr.Accordion("βš™οΈ Advanced Settings", open=False):
            # Parameters laid out one row at a time
            with gr.Group():
                with gr.Row():
                    seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
                                     label="Seed", info="0 for random seed, positive for reproducible results")
                    image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
                                              value="1:1", label="Image Ratio",
                                              info="The longer side is fixed to 1024")

                with gr.Row():
                    cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                               label="CFG Text Scale",
                                               info="Controls how strongly the model follows the text prompt (4.0-8.0)")
                    cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
                                             label="CFG Interval",
                                             info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                  value="global", label="CFG Renorm Type",
                                                  info="If the generated image is blurry, use 'global'")
                    cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                               label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                              label="Timesteps", info="Total denoising steps")
                    timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
                                               label="Timestep Shift",
                                               info="Higher values for layout, lower for details")

                # Thinking parameters in a single row
                thinking_params = gr.Group(visible=False)
                with thinking_params:
                    with gr.Row():
                        do_sample = gr.Checkbox(label="Sampling", value=False,
                                                info="Enable sampling for text generation")
                        max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
                                                      label="Max Think Tokens",
                                                      info="Maximum number of tokens for thinking")
                        text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                     label="Temperature",
                                                     info="Controls randomness in text generation")

        thinking_output = gr.Textbox(label="Thinking Process", visible=False)
        img_output = gr.Image(label="Generated Image", elem_classes=["image-frame"])
        gen_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

        # Dynamically show/hide the thinking process box and parameters
        def update_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        show_thinking.change(
            fn=update_thinking_visibility,
            inputs=[show_thinking],
            outputs=[thinking_output, thinking_params]
        )
        gr.on(
            triggers=[gen_btn.click, txt_input.submit],
            fn=text_to_image,
            inputs=[
                txt_input, use_web_search, show_thinking, cfg_text_scale, cfg_interval,
                timestep_shift, num_timesteps,
                cfg_renorm_min, cfg_renorm_type,
                max_think_token_n, do_sample, text_temperature, seed, image_ratio
            ],
            outputs=[img_output, thinking_output]
        )

    with gr.Tab("πŸ–ŒοΈ Image Edit"):
        with gr.Row():
            with gr.Column(scale=1):
                edit_image_input = gr.Image(label="Input Image",
                                            value=load_example_image('test_images/women.jpg'),
                                            elem_classes=["image-frame"])
                edit_prompt = gr.Textbox(
                    label="Edit Prompt",
                    value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes.",
                    lines=2
                )

            with gr.Column(scale=1):
                edit_image_output = gr.Image(label="Edited Result", elem_classes=["image-frame"])
                edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)

        with gr.Row():
            edit_use_web_search = gr.Checkbox(
                label="πŸ” Enable Web Search",
                value=False,
                info="Search for references and context to improve editing"
            )
            edit_show_thinking = gr.Checkbox(label="πŸ’­ Show Thinking Process", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("βš™οΈ Advanced Settings", open=False):
            with gr.Group():
                with gr.Row():
                    edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
                                          label="Seed", info="0 for random seed, positive for reproducible results")
                    edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                                    label="CFG Text Scale",
                                                    info="Controls how strongly the model follows the text prompt")

                with gr.Row():
                    edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
                                                   label="CFG Image Scale",
                                                   info="Controls how much the model preserves input image details")
                    edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                  label="CFG Interval",
                                                  info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                       value="text_channel", label="CFG Renorm Type",
                                                       info="If the generated image is blurry, use 'global'")
                    edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                    label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                                   label="Timesteps", info="Total denoising steps")
                    edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
                                                    label="Timestep Shift",
                                                    info="Higher values for layout, lower for details")

                # Thinking parameters in a single row
                edit_thinking_params = gr.Group(visible=False)
                with edit_thinking_params:
                    with gr.Row():
                        edit_do_sample = gr.Checkbox(label="Sampling", value=False,
                                                     info="Enable sampling for text generation")
                        edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64,
                                                           interactive=True, label="Max Think Tokens",
                                                           info="Maximum number of tokens for thinking")
                        edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1,
                                                          interactive=True, label="Temperature",
                                                          info="Controls randomness in text generation")

        edit_btn = gr.Button("✏️ Apply Edit", variant="primary", size="lg")

        # Dynamically show/hide the thinking process box for editing
        def update_edit_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        edit_show_thinking.change(
            fn=update_edit_thinking_visibility,
            inputs=[edit_show_thinking],
            outputs=[edit_thinking_output, edit_thinking_params]
        )
        gr.on(
            triggers=[edit_btn.click, edit_prompt.submit],
            fn=edit_image,
            inputs=[
                edit_image_input, edit_prompt, edit_use_web_search, edit_show_thinking,
                edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
                edit_timestep_shift, edit_num_timesteps,
                edit_cfg_renorm_min, edit_cfg_renorm_type,
                edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
            ],
            outputs=[edit_image_output, edit_thinking_output]
        )

    with gr.Tab("πŸ–ΌοΈ Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(label="Input Image",
                                     value=load_example_image('test_images/meme.jpg'),
                                     elem_classes=["image-frame"])
                understand_prompt = gr.Textbox(
                    label="Question",
                    value="Can someone explain what's funny about this meme??",
                    lines=2
                )

            with gr.Column(scale=1):
                txt_output = gr.Textbox(label="AI Response", lines=20)

        with gr.Row():
            understand_use_web_search = gr.Checkbox(
                label="πŸ” Enable Web Search",
                value=False,
                info="Search for context and references to better understand the image"
            )
            understand_show_thinking = gr.Checkbox(label="πŸ’­ Show Thinking Process", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("βš™οΈ Advanced Settings", open=False):
            with gr.Row():
                understand_do_sample = gr.Checkbox(label="Sampling", value=False,
                                                   info="Enable sampling for text generation")
                understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05,
                                                        interactive=True, label="Temperature",
                                                        info="Controls randomness in text generation (0=deterministic, 1=creative)")
                understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
                                                      label="Max New Tokens",
                                                      info="Maximum length of generated text, including potential thinking")

        img_understand_btn = gr.Button("πŸ” Analyze Image", variant="primary", size="lg")

        gr.on(
            triggers=[img_understand_btn.click, understand_prompt.submit],
            fn=image_understanding,
            inputs=[
                img_input, understand_prompt, understand_use_web_search, understand_show_thinking,
                understand_do_sample, understand_text_temperature, understand_max_new_tokens
            ],
            outputs=txt_output
        )

demo.launch(share=True)
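# Note (hedged): the event handlers above are generators, so partial results are
# streamed to the UI as they arrive. Recent Gradio releases enable the request
# queue by default; on older versions an explicit queue may be needed before
# launching, e.g.:
#
#   demo.queue().launch(share=True)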