# Copyright (c) iMED
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
os.environ['TORCHDYNAMO_DISABLE'] = "1"

import sys
import copy
import re
from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer
import numpy as np
from PIL import Image
import io
import base64
import warnings

warnings.filterwarnings('ignore')

# Try to import spaces, define placeholder decorator if failed
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            if func:
                return func
            return lambda f: f

# Check if GPU is available
HAS_GPU = torch.cuda.is_available()

# Try to install flash-attn (only in GPU environment)
if HAS_GPU:
    try:
        import subprocess
        subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation',
                       env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
                       shell=True, capture_output=True, timeout=30)
        import flash_attn
        HAS_FLASH_ATTN = True
    except Exception as e:
        print(f"Flash Attention installation failed: {e}")
        HAS_FLASH_ATTN = False
else:
    HAS_FLASH_ATTN = False

# Default model checkpoint path
DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'

# Default system prompt
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between user and assistant. The user asks a question, and the assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, "
    "respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
)
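# Illustrative only: with the system prompt above, the model is expected to wrap its
# output as
#   <think> reasoning process here </think> <answer> answer here </answer>
# and the display helpers below (call_local_model / _parse_text) rewrite those tags
# into the "**Reasoning Process**" and "**Final Answer**" sections shown in the chat.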
# CSS styles for interface beautification
CUSTOM_CSS = """
/* Main container styles */
.container { max-width: 1400px; margin: 0 auto; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }

/* Title styles */
.main-title {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 10px;
}
.sub-title { text-align: center; color: #666; font-size: 1.2em; margin-bottom: 30px; }

/* Chatbot styles */
.control-height { border-radius: 15px; border: 1px solid #e0e0e0; box-shadow: 0 2px 10px rgba(0,0,0,0.1); }

/* Button styles */
.custom-button { border-radius: 8px; font-weight: 500; transition: all 0.3s ease; }
.custom-button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }

/* Input box styles */
textarea {
    border-radius: 10px !important;
    border: 1px solid #d0d0d0 !important;
    padding: 10px !important;
    font-size: 14px !important;
}
textarea:focus { border-color: #667eea !important; box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.1) !important; }

/* File upload area styles */
.file-upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    background: linear-gradient(135deg, rgba(102, 126, 234, 0.05) 0%, rgba(118, 75, 162, 0.05) 100%);
    transition: all 0.3s ease;
}
.file-upload-area:hover {
    border-color: #764ba2;
    background: linear-gradient(135deg, rgba(102, 126, 234, 0.1) 0%, rgba(118, 75, 162, 0.1) 100%);
}

/* Image preview styles */
.image-container { display: flex; flex-wrap: wrap; gap: 10px; margin: 10px 0; }
.image-preview { width: 100px; height: 100px; object-fit: cover; border-radius: 8px; border: 2px solid #e0e0e0; }

/* Status indicator */
.status-indicator {
    display: inline-block;
    padding: 5px 15px;
    border-radius: 20px;
    font-size: 12px;
    font-weight: 500;
    margin-left: 10px;
}
.gpu-status { background-color: #4caf50; color: white; }
.cpu-status { background-color: #ff9800; color: white; }

/* Parameter section styles */
.parameter-section { background: #f5f5f5; border-radius: 10px; padding: 15px; margin-bottom: 15px; }
.parameter-title { font-weight: bold; color: #333; margin-bottom: 10px; font-size: 1.1em; }
"""


def _get_args():
    """Parse command line arguments"""
    parser = ArgumentParser()
    parser.add_argument('-c', '--checkpoint-path', type=str, default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--share', action='store_true', default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser', action='store_true', default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')

    args = parser.parse_args()
    return args


def encode_image_pil(image_path):
    """Encode image to base64 using PIL"""
    try:
        if isinstance(image_path, str):
            img = Image.open(image_path)
        elif isinstance(image_path, np.ndarray):
            img = Image.fromarray(image_path)
        elif isinstance(image_path, Image.Image):
            img = image_path
        else:
            print(f"Unsupported image type: {type(image_path)}")
            return None

        if img.mode not in ('RGB', 'RGBA'):
            img = img.convert('RGB')

        max_size = (1024, 1024)
        img.thumbnail(max_size, Image.Resampling.LANCZOS)

        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None


def _load_model_processor(args):
    """Intelligently load model, automatically choose CPU or GPU based on environment"""
    global HAS_GPU, HAS_FLASH_ATTN

    use_gpu = HAS_GPU and not args.cpu_only
    device = 'cuda' if use_gpu else 'cpu'

    print(f"{'='*50}")
    print(f"🚀 Loading model: {args.checkpoint_path}")
    print(f"📱 Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
    print(f"⚡ Flash Attention: {'Enabled' if (use_gpu and HAS_FLASH_ATTN) else 'Disabled'}")
    print(f"{'='*50}")

    model_kwargs = {
        'pretrained_model_name_or_path': args.checkpoint_path,
        'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
    }

    if use_gpu and HAS_FLASH_ATTN:
        model_kwargs['attn_implementation'] = 'flash_attention_2'

    if use_gpu:
        model_kwargs['device_map'] = 'auto'
    else:
        model_kwargs['device_map'] = None
        model_kwargs['low_cpu_mem_usage'] = True

    try:
        model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
        model.eval()
        # Note: even with device_map='auto', we might need to move a CPU-only model explicitly
        if not use_gpu:
            model = model.to(device)
    except Exception as e:
        print(f"⚠️ Failed to load model with optimal settings: {e}")
        print("🔄 Falling back to CPU mode...")
        model_kwargs = {
            'pretrained_model_name_or_path': args.checkpoint_path,
            'torch_dtype': torch.float32,
            'device_map': None,
            'low_cpu_mem_usage': True
        }
        model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
        model = model.to('cpu')
        model.eval()
        use_gpu = False
        device = 'cpu'

    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    print(f"✅ Model loaded successfully on {device}")
    return model, processor, device
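# Illustrative sketch, not called anywhere in this demo: a minimal non-streaming
# generation using the model/processor returned by _load_model_processor(). It mirrors
# the apply_chat_template()/generate() calls made in call_local_model() further down;
# the default question text is just a placeholder.
def _smoke_test(model, processor, device, question="Reply with a short greeting."):
    messages = [
        {"role": "system", "content": [{"type": "text", "text": DEFAULT_SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "text", "text": question}]},
    ]
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True,
        return_dict=True, return_tensors="pt"
    ).to(device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    # Drop the prompt tokens so only the newly generated text is decoded
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.tokenizer.decode(new_tokens[0], skip_special_tokens=True)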
def _parse_text(text):
    """Parse text for display formatting"""
    if text is None:
        return ""
    text = str(text)

    lines = text.split('\n')
    lines = [line for line in lines if line != '']
    count = 0
    for i, line in enumerate(lines):
        # Rewrite the model's reasoning tags into Markdown headings
        if "<think>" in line:
            line = line.replace("<think>", "**Reasoning Process**:\n")
        if "</think>" in line:
            line = line.replace("</think>", "")
        if "<answer>" in line:
            line = line.replace("<answer>", "**Final Answer**:\n")
        if "</answer>" in line:
            line = line.replace("</answer>", "")
        lines[i] = line  # keep the tag replacement even for the first line

        if '```' in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: start an HTML code block with the declared language
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence: end the HTML code block
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters that Markdown/HTML would interpret
                    line = line.replace('`', r'\`')
                    line = line.replace('<', '&lt;')
                    line = line.replace('>', '&gt;')
                    line = line.replace(' ', '&nbsp;')
                    line = line.replace('*', '&ast;')
                    line = line.replace('_', '&lowbar;')
                    line = line.replace('-', '&#45;')
                    line = line.replace('.', '&#46;')
                    line = line.replace('!', '&#33;')
                    line = line.replace('(', '&#40;')
                    line = line.replace(')', '&#41;')
                    line = line.replace('$', '&#36;')
                lines[i] = '<br>' + line
    text = ''.join(lines)
    return text


def _remove_image_special(text):
    """Remove special image tags from text"""
    if text is None:
        return ""
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)


def _gc():
    """Garbage collection to free memory"""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _transform_messages(original_messages, system_prompt):
    """Transform messages with custom system prompt"""
    transformed_messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_content.append({'type': 'image', 'image': item['image']})
            elif 'text' in item:
                new_content.append({'type': 'text', 'text': item['text']})
        if new_content:
            transformed_messages.append({'role': message['role'], 'content': new_content})
    return transformed_messages


def normalize_task_history_item(item):
    """Normalize items in task_history to a dictionary format"""
    if isinstance(item, dict):
        return {'text': item.get('text', ''),
                'images': item.get('images', []),
                'response': item.get('response', None)}
    elif isinstance(item, (list, tuple)) and len(item) >= 2:
        query, response = item[0], item[1]
        if isinstance(query, (list, tuple)):
            return {'text': '', 'images': list(query), 'response': response}
        else:
            return {'text': str(query) if query else '', 'images': [], 'response': response}
    else:
        return {'text': str(item) if item else '', 'images': [], 'response': None}
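# Illustrative only: add_text_and_files() (defined inside _launch_demo below) stores
# one dict per turn in task_history, e.g.
#   {'text': 'What does this image show?', 'images': ['/tmp/gradio/abc/photo.png'], 'response': None}
# and after encode_image_pil() + _transform_messages() the model receives messages like
#   [{'role': 'system', 'content': [{'type': 'text', 'text': DEFAULT_SYSTEM_PROMPT}]},
#    {'role': 'user', 'content': [{'type': 'image', 'image': '<base64 PNG>'},
#                                 {'type': 'text', 'text': 'What does this image show?'}]}]
# The path and question text here are placeholders, not values from the original code.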
def _launch_demo(args, model, processor, device):
    """Launch the Gradio demo interface"""

    def call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
        """Call the local model with streaming response"""
        messages = _transform_messages(messages, system_prompt)

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )

        # Move the input tensors to the model's device. For compatibility with
        # ZeroGPU's `torch.compile`, `device` is a plain string ('cuda' or 'cpu')
        # taken from the enclosing scope rather than a `torch.device` object; this
        # avoids both the "device mismatch" and the "ConstantVariable" errors.
        inputs = inputs.to(device)

        tokenizer = processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = {
            'max_new_tokens': max_tokens,
            'do_sample': True,
            'temperature': temperature,
            'top_p': top_p,
            'top_k': 20,
            'streamer': streamer,
            **inputs
        }

        with torch.inference_mode():
            thread = Thread(target=model.generate, kwargs=gen_kwargs)
            thread.start()

            generated_text = ''
            for new_text in streamer:
                generated_text += new_text
                display_text = generated_text
                # Rewrite the reasoning tags on the fly so the partial answer renders nicely
                if "<think>" in display_text:
                    display_text = display_text.replace("<think>", "**Reasoning Process**:\n")
                if "</think>" in display_text:
                    display_text = display_text.replace("</think>", "\n")
                if "<answer>" in display_text:
                    display_text = display_text.replace("<answer>", "**Final Answer**:\n")
                if "</answer>" in display_text:
                    display_text = display_text.replace("</answer>", "")
                yield display_text, generated_text

    @spaces.GPU(duration=120)
    def predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
        if not _chatbot or not task_history:
            yield _chatbot
            return

        chat_query = _chatbot[-1][0]
        last_item = normalize_task_history_item(task_history[-1])

        if not chat_query and not last_item['text'] and not last_item['images']:
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return

        print(f'User query: {last_item}')

        history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
        full_response_raw = ''
        messages = []

        for i, item in enumerate(history_cp):
            content = []
            if item['images']:
                for img_path in item['images']:
                    if img_path:
                        encoded_img = encode_image_pil(img_path)
                        if encoded_img:
                            content.append({'image': encoded_img})
            if item['text']:
                content.append({'text': str(item['text'])})

            if item['response'] is None:
                if content:
                    messages.append({'role': 'user', 'content': content})
            else:
                if content:
                    messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})

        try:
            for response_display, response_raw in call_local_model(
                    model, processor, messages, system_prompt, temperature, top_p, max_tokens):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response_display)))
                yield _chatbot
                full_response_raw = response_raw

            task_history[-1]['response'] = full_response_raw
            print(f'Assistant: {full_response_raw}')
        except Exception as e:
            print(f"Error during generation: {e}")
            import traceback
            traceback.print_exc()
            error_msg = f"Error: {str(e)}"
            _chatbot[-1] = (_parse_text(chat_query), error_msg)
            task_history[-1]['response'] = error_msg
            yield _chatbot

    @spaces.GPU(duration=120)
    def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
        if not task_history or not _chatbot:
            yield _chatbot
            return

        last_item = normalize_task_history_item(task_history[-1])
        if last_item['response'] is None:
            yield _chatbot
            return

        last_item['response'] = None
        task_history[-1] = last_item
        _chatbot.pop(-1)

        display_message_parts = []
        if last_item['images']:
            display_message_parts.append(f"[Uploaded {len(last_item['images'])} images]")
        if last_item['text']:
            display_message_parts.append(last_item['text'])
        display_message = " ".join(display_message_parts)

        _chatbot.append([_parse_text(display_message), None])

        for updated_chatbot in predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
            yield updated_chatbot
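    # Illustrative only: the `_chatbot` value passed to predict()/regenerate() is the
    # list of [user_message, assistant_message] pairs rendered by gr.Chatbot, e.g.
    #   [["[Uploaded 1 images] What does this show?", "**Reasoning Process**:\n..."]]
    # while `task_history` keeps the raw dicts created by add_text_and_files() below.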
    def add_text_and_files(history, task_history, text, files):
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []

        has_text = text and text.strip()
        has_files = files and len(files) > 0
        if not has_text and not has_files:
            return history, task_history, text, files

        display_parts, file_paths = [], []
        if has_files:
            for file in files:
                if file and hasattr(file, 'name'):
                    file_paths.append(file.name)
            if file_paths:
                display_parts.append(f"[Uploaded {len(file_paths)} images]")
        if has_text:
            display_parts.append(text)

        display_message = " ".join(display_parts)
        history.append([_parse_text(display_message), None])
        task_history.append({'text': text if has_text else '', 'images': file_paths, 'response': None})
        return history, task_history, '', None

    def reset_state():
        _gc()
        return [], [], None

    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
        # Page header (uses the classes defined in CUSTOM_CSS)
        gr.HTML(f"""
            <div class="container">
                <h1 class="main-title">IntrinSight Assistant</h1>
                <p class="sub-title">
                    Powered by IntrinSight-4B Model
                    <span class="status-indicator {'gpu-status' if device == 'cuda' else 'cpu-status'}">
                        {'🚀 GPU Mode' if device == 'cuda' else '💻 CPU Mode'}
                    </span>
                </p>
            </div>
        """)

        task_history = gr.State([])

        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    label='IntrinSight-4B Chat Interface',
                    elem_classes='control-height',
                    height=600,
                    avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
                )
                with gr.Row():
                    query = gr.Textbox(lines=3, label='💬 Message Input',
                                       placeholder="Enter your question here...",
                                       elem_classes="custom-input")
                with gr.Row():
                    addfile_btn = gr.File(
                        label="📸 Upload Images (Drag & Drop Supported, Multiple Selection)",
                        file_count="multiple",
                        file_types=["image"],
                        elem_classes="file-upload-area"
                    )
                with gr.Row():
                    submit_btn = gr.Button('🚀 Send', variant="primary", elem_classes="custom-button")
                    regen_btn = gr.Button('🔄 Regenerate', variant="secondary", elem_classes="custom-button")
                    empty_bin = gr.Button('🗑️ Clear History', variant="stop", elem_classes="custom-button")

            with gr.Column(scale=2):
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### ⚙️ System Configuration")
                    system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT,
                                               lines=5, placeholder="Enter system prompt here...")

                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### 🎛️ Generation Parameters")
                    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                                            label="Temperature (Creativity)",
                                            info="Higher values make output more random")
                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05,
                                      label="Top-p (Nucleus Sampling)",
                                      info="Cumulative probability for token selection")
                    max_tokens = gr.Slider(minimum=256, maximum=16384, value=8192, step=256,
                                           label="Max Tokens",
                                           info="Maximum number of tokens to generate")

                gr.Markdown(f"""
### 📋 Instructions

**Basic Usage:**
- **Text Chat**: Enter your question and click Send
- **Image Upload**: Drag & drop or select multiple images
- **Mixed Input**: Upload images and enter text, then click Send
- **Parameters**: Adjust generation settings as needed

**Performance Info:**
- Current Mode: **{'GPU Acceleration' if device == 'cuda' else 'CPU Mode'}**
- Flash Attention: **{'Enabled' if (device == 'cuda' and HAS_FLASH_ATTN) else 'Disabled'}**
- Recommended Image Size: < 1024×1024

### ⚠️ Disclaimer
This demo is subject to the Gemma license agreement. Please do not generate or disseminate harmful content.
""")

        submit_btn.click(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )

        empty_bin.click(reset_state, outputs=[chatbot, task_history, addfile_btn], show_progress=True)

        regen_btn.click(
            regenerate,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )

        query.submit(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )

    demo.queue(max_size=10).launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
        show_error=True
    )


def main():
    """Main entry point"""
    args = _get_args()
    model, processor, device = _load_model_processor(args)
    _launch_demo(args, model, processor, device)


if __name__ == '__main__':
    main()