# Copyright (c) iMED
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
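"""Gradio web demo for the IntrinSight-4B image-text-to-text model.

The app streams model output token by token, supports multi-image uploads,
and falls back from GPU (optionally with Flash Attention) to CPU when needed.
"""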
import os
os.environ['TORCHDYNAMO_DISABLE'] = "1"

import sys
import copy
import re
import io
import base64
import warnings
from argparse import ArgumentParser
from threading import Thread

import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText, TextIteratorStreamer

warnings.filterwarnings('ignore')
# Try to import `spaces` (Hugging Face ZeroGPU); define a no-op placeholder decorator if unavailable
try:
    import spaces
    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class spaces:
        @staticmethod
        def GPU(func=None, **kwargs):
            if func is not None:
                return func
            return lambda f: f
# Check whether a GPU is available
HAS_GPU = torch.cuda.is_available()

# Try to install flash-attn (only in a GPU environment)
HAS_FLASH_ATTN = False
if HAS_GPU:
    try:
        import subprocess
        # Merge with os.environ so PATH is preserved; otherwise pip cannot be resolved.
        subprocess.run(
            'pip install flash-attn==2.7.4.post1 --no-build-isolation',
            env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
            shell=True,
            capture_output=True,
            timeout=30,
        )
        import flash_attn
        HAS_FLASH_ATTN = True
    except Exception as e:
        print(f"Flash Attention installation failed: {e}")
        HAS_FLASH_ATTN = False
# Default model checkpoint path
DEFAULT_CKPT_PATH = 'qiuxi337/IntrinSight-4B'

# Default system prompt
DEFAULT_SYSTEM_PROMPT = (
    "A conversation between user and assistant. The user asks a question, and the assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>."
)
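# The <think>/<answer> tags requested by this prompt are rewritten into
# human-readable headings ("Reasoning Process" / "Final Answer") by
# _parse_text() and by the streaming loop in call_local_model() below.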
# CSS styles for interface beautification
CUSTOM_CSS = """
/* Main container styles */
.container {
    max-width: 1400px;
    margin: 0 auto;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}

/* Title styles */
.main-title {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3em;
    font-weight: bold;
    text-align: center;
    margin-bottom: 10px;
}
.sub-title {
    text-align: center;
    color: #666;
    font-size: 1.2em;
    margin-bottom: 30px;
}

/* Chatbot styles */
.control-height {
    border-radius: 15px;
    border: 1px solid #e0e0e0;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}

/* Button styles */
.custom-button {
    border-radius: 8px;
    font-weight: 500;
    transition: all 0.3s ease;
}
.custom-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 12px rgba(0,0,0,0.15);
}

/* Input box styles */
textarea {
    border-radius: 10px !important;
    border: 1px solid #d0d0d0 !important;
    padding: 10px !important;
    font-size: 14px !important;
}
textarea:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 2px rgba(102, 126, 234, 0.1) !important;
}

/* File upload area styles */
.file-upload-area {
    border: 2px dashed #667eea;
    border-radius: 10px;
    padding: 20px;
    text-align: center;
    background: linear-gradient(135deg, rgba(102, 126, 234, 0.05) 0%, rgba(118, 75, 162, 0.05) 100%);
    transition: all 0.3s ease;
}
.file-upload-area:hover {
    border-color: #764ba2;
    background: linear-gradient(135deg, rgba(102, 126, 234, 0.1) 0%, rgba(118, 75, 162, 0.1) 100%);
}

/* Image preview styles */
.image-container {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin: 10px 0;
}
.image-preview {
    width: 100px;
    height: 100px;
    object-fit: cover;
    border-radius: 8px;
    border: 2px solid #e0e0e0;
}

/* Status indicator */
.status-indicator {
    display: inline-block;
    padding: 5px 15px;
    border-radius: 20px;
    font-size: 12px;
    font-weight: 500;
    margin-left: 10px;
}
.gpu-status {
    background-color: #4caf50;
    color: white;
}
.cpu-status {
    background-color: #ff9800;
    color: white;
}

/* Parameter section styles */
.parameter-section {
    background: #f5f5f5;
    border-radius: 10px;
    padding: 15px;
    margin-bottom: 15px;
}
.parameter-title {
    font-weight: bold;
    color: #333;
    margin-bottom: 10px;
    font-size: 1.1em;
}
"""
def _get_args():
    """Parse command line arguments."""
    parser = ArgumentParser()
    parser.add_argument('-c', '--checkpoint-path',
                        type=str,
                        default=DEFAULT_CKPT_PATH,
                        help='Checkpoint name or path, default to %(default)r')
    parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
    parser.add_argument('--share',
                        action='store_true',
                        default=False,
                        help='Create a publicly shareable link for the interface.')
    parser.add_argument('--inbrowser',
                        action='store_true',
                        default=False,
                        help='Automatically launch the interface in a new tab on the default browser.')
    parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
    parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')
    args = parser.parse_args()
    return args
def encode_image_pil(image_path):
    """Encode an image to base64 using PIL."""
    try:
        if isinstance(image_path, str):
            img = Image.open(image_path)
        elif isinstance(image_path, np.ndarray):
            img = Image.fromarray(image_path)
        elif isinstance(image_path, Image.Image):
            img = image_path
        else:
            print(f"Unsupported image type: {type(image_path)}")
            return None
        if img.mode not in ('RGB', 'RGBA'):
            img = img.convert('RGB')
        max_size = (1024, 1024)
        img.thumbnail(max_size, Image.Resampling.LANCZOS)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error encoding image: {e}")
        return None
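# Illustrative usage (the file name is hypothetical): encode_image_pil("cat.png")
# returns a base64-encoded PNG string, resized to fit within 1024x1024, or None
# if the input cannot be read or encoded.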
def _load_model_processor(args):
    """Load the model, automatically choosing CPU or GPU based on the environment."""
    global HAS_GPU, HAS_FLASH_ATTN
    use_gpu = HAS_GPU and not args.cpu_only
    device = 'cuda' if use_gpu else 'cpu'

    print('=' * 50)
    print(f"🚀 Loading model: {args.checkpoint_path}")
    print(f"📱 Device: {'GPU (CUDA)' if use_gpu else 'CPU'}")
    print(f"⚡ Flash Attention: {'Enabled' if (use_gpu and HAS_FLASH_ATTN) else 'Disabled'}")
    print('=' * 50)

    model_kwargs = {
        'pretrained_model_name_or_path': args.checkpoint_path,
        'torch_dtype': torch.bfloat16 if use_gpu else torch.float32,
    }
    if use_gpu and HAS_FLASH_ATTN:
        model_kwargs['attn_implementation'] = 'flash_attention_2'
    if use_gpu:
        model_kwargs['device_map'] = 'auto'
    else:
        model_kwargs['device_map'] = None
        model_kwargs['low_cpu_mem_usage'] = True

    try:
        model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
        model.eval()
        # With device_map=None (the CPU path) the model is not placed automatically,
        # so move it to the target device explicitly.
        if not use_gpu:
            model = model.to(device)
    except Exception as e:
        print(f"⚠️ Failed to load model with optimal settings: {e}")
        print("🔄 Falling back to CPU mode...")
        model_kwargs = {
            'pretrained_model_name_or_path': args.checkpoint_path,
            'torch_dtype': torch.float32,
            'device_map': None,
            'low_cpu_mem_usage': True,
        }
        model = AutoModelForImageTextToText.from_pretrained(**model_kwargs)
        model = model.to('cpu')
        model.eval()
        use_gpu = False
        device = 'cpu'

    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    print(f"✅ Model loaded successfully on {device}")
    return model, processor, device
def _parse_text(text):
    """Parse text for display formatting."""
    if text is None:
        return ""
    text = str(text)
    lines = text.split('\n')
    lines = [line for line in lines if line != '']
    count = 0
    for i, line in enumerate(lines):
        if "<think>" in line:
            line = line.replace("<think>", "**Reasoning Process**:\n")
        if "</think>" in line:
            line = line.replace("</think>", "")
        if "<answer>" in line:
            line = line.replace("<answer>", "**Final Answer**:\n")
        if "</answer>" in line:
            line = line.replace("</answer>", "")
        if '```' in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Escape markdown/HTML-sensitive characters inside code blocks.
                    line = line.replace('`', r'\`')
                    line = line.replace('<', '&lt;')
                    line = line.replace('>', '&gt;')
                    line = line.replace(' ', '&nbsp;')
                    line = line.replace('*', '&ast;')
                    line = line.replace('_', '&lowbar;')
                    line = line.replace('-', '&#45;')
                    line = line.replace('.', '&#46;')
                    line = line.replace('!', '&#33;')
                    line = line.replace('(', '&#40;')
                    line = line.replace(')', '&#41;')
                    line = line.replace('$', '&#36;')
                lines[i] = '<br>' + line
    text = ''.join(lines)
    return text
def _remove_image_special(text):
    """Remove special image tags from text."""
    if text is None:
        return ""
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)


def _gc():
    """Garbage collection to free memory."""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def _transform_messages(original_messages, system_prompt):
    """Transform messages, prepending the custom system prompt."""
    transformed_messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_content.append({'type': 'image', 'image': item['image']})
            elif 'text' in item:
                new_content.append({'type': 'text', 'text': item['text']})
        if new_content:
            transformed_messages.append({'role': message['role'], 'content': new_content})
    return transformed_messages


def normalize_task_history_item(item):
    """Normalize items in task_history to a dictionary format."""
    if isinstance(item, dict):
        return {'text': item.get('text', ''), 'images': item.get('images', []), 'response': item.get('response', None)}
    elif isinstance(item, (list, tuple)) and len(item) >= 2:
        query, response = item[0], item[1]
        if isinstance(query, (list, tuple)):
            return {'text': '', 'images': list(query), 'response': response}
        else:
            return {'text': str(query) if query else '', 'images': [], 'response': response}
    else:
        return {'text': str(item) if item else '', 'images': [], 'response': None}
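# Illustrative inputs accepted by normalize_task_history_item (values are hypothetical):
#   {'text': 'hi', 'images': ['a.png'], 'response': None}  -> returned as-is
#   (['a.png', 'b.png'], 'answer')                          -> {'text': '', 'images': ['a.png', 'b.png'], 'response': 'answer'}
#   ('hi', None)                                            -> {'text': 'hi', 'images': [], 'response': None}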
def _launch_demo(args, model, processor, device):
    """Launch the Gradio demo interface."""

    def call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
        """Call the local model with a streaming response."""
        messages = _transform_messages(messages, system_prompt)
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )
        # Move the input tensors to the target device. `device` is a plain string
        # ('cuda' or 'cpu') rather than a torch.device object, which keeps this
        # compatible with ZeroGPU / torch.compile and avoids both the
        # "device mismatch" and the "ConstantVariable" errors.
        inputs = inputs.to(device)

        tokenizer = processor.tokenizer
        streamer = TextIteratorStreamer(tokenizer, timeout=2000.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {
            'max_new_tokens': max_tokens, 'do_sample': True, 'temperature': temperature,
            'top_p': top_p, 'top_k': 20, 'streamer': streamer, **inputs
        }

        # Run generation in a background thread. inference_mode is entered inside
        # the thread because torch's mode contexts are thread-local.
        def _generate():
            with torch.inference_mode():
                model.generate(**gen_kwargs)

        thread = Thread(target=_generate)
        thread.start()

        generated_text = ''
        for new_text in streamer:
            generated_text += new_text
            display_text = generated_text
            display_text = display_text.replace("<think>", "**Reasoning Process**:\n")
            display_text = display_text.replace("</think>", "\n")
            display_text = display_text.replace("<answer>", "**Final Answer**:\n")
            display_text = display_text.replace("</answer>", "")
            yield display_text, generated_text
    def predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
        if not _chatbot or not task_history:
            yield _chatbot
            return
        chat_query = _chatbot[-1][0]
        last_item = normalize_task_history_item(task_history[-1])
        if not chat_query and not last_item['text'] and not last_item['images']:
            _chatbot.pop()
            task_history.pop()
            yield _chatbot
            return
        print(f'User query: {last_item}')

        history_cp = [normalize_task_history_item(item) for item in copy.deepcopy(task_history)]
        full_response_raw = ''
        messages = []
        for item in history_cp:
            content = []
            if item['images']:
                for img_path in item['images']:
                    if img_path:
                        encoded_img = encode_image_pil(img_path)
                        if encoded_img:
                            content.append({'image': encoded_img})
            if item['text']:
                content.append({'text': str(item['text'])})
            if item['response'] is None:
                if content:
                    messages.append({'role': 'user', 'content': content})
            else:
                if content:
                    messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'text': str(item['response'])}]})

        try:
            for response_display, response_raw in call_local_model(model, processor, messages, system_prompt, temperature, top_p, max_tokens):
                _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response_display)))
                yield _chatbot
                full_response_raw = response_raw
            task_history[-1]['response'] = full_response_raw
            print(f'Assistant: {full_response_raw}')
        except Exception as e:
            print(f"Error during generation: {e}")
            import traceback
            traceback.print_exc()
            error_msg = f"Error: {str(e)}"
            _chatbot[-1] = (_parse_text(chat_query), error_msg)
            task_history[-1]['response'] = error_msg
            yield _chatbot
    def regenerate(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
        if not task_history or not _chatbot:
            yield _chatbot
            return
        last_item = normalize_task_history_item(task_history[-1])
        if last_item['response'] is None:
            yield _chatbot
            return
        last_item['response'] = None
        task_history[-1] = last_item
        _chatbot.pop(-1)

        display_message_parts = []
        if last_item['images']:
            display_message_parts.append(f"[Uploaded {len(last_item['images'])} images]")
        if last_item['text']:
            display_message_parts.append(last_item['text'])
        display_message = " ".join(display_message_parts)
        _chatbot.append([_parse_text(display_message), None])

        for updated_chatbot in predict(_chatbot, task_history, system_prompt, temperature, top_p, max_tokens):
            yield updated_chatbot
    def add_text_and_files(history, task_history, text, files):
        history = history if history is not None else []
        task_history = task_history if task_history is not None else []
        has_text = text and text.strip()
        has_files = files and len(files) > 0
        if not has_text and not has_files:
            return history, task_history, text, files

        display_parts, file_paths = [], []
        if has_files:
            for file in files:
                # gr.File may yield plain file paths (str) or tempfile-like objects
                # with a .name attribute, depending on the Gradio version.
                if isinstance(file, str):
                    file_paths.append(file)
                elif file is not None and hasattr(file, 'name'):
                    file_paths.append(file.name)
            if file_paths:
                display_parts.append(f"[Uploaded {len(file_paths)} images]")
        if has_text:
            display_parts.append(text)

        display_message = " ".join(display_parts)
        history.append([_parse_text(display_message), None])
        task_history.append({'text': text if has_text else '', 'images': file_paths, 'response': None})
        return history, task_history, '', None

    def reset_state():
        _gc()
        return [], [], None
    with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
        gr.HTML(f"""
        <div class="container">
            <h1 class="main-title">IntrinSight Assistant</h1>
            <p class="sub-title">
                Powered by IntrinSight-4B Model
                <span class="status-indicator {'gpu-status' if device == 'cuda' else 'cpu-status'}">
                    {'🚀 GPU Mode' if device == 'cuda' else '💻 CPU Mode'}
                </span>
            </p>
        </div>
        """)

        task_history = gr.State([])

        with gr.Row():
            with gr.Column(scale=4):
                chatbot = gr.Chatbot(
                    label='IntrinSight-4B Chat Interface', elem_classes='control-height', height=600,
                    avatar_images=(None, "https://em-content.zobj.net/thumbs/240/twitter/348/robot_1f916.png")
                )
                with gr.Row():
                    query = gr.Textbox(lines=3, label='💬 Message Input', placeholder="Enter your question here...", elem_classes="custom-input")
                with gr.Row():
                    addfile_btn = gr.File(
                        label="📸 Upload Images (Drag & Drop Supported, Multiple Selection)", file_count="multiple",
                        file_types=["image"], elem_classes="file-upload-area"
                    )
                with gr.Row():
                    submit_btn = gr.Button('🚀 Send', variant="primary", elem_classes="custom-button")
                    regen_btn = gr.Button('🔄 Regenerate', variant="secondary", elem_classes="custom-button")
                    empty_bin = gr.Button('🗑️ Clear History', variant="stop", elem_classes="custom-button")
            with gr.Column(scale=2):
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### ⚙️ System Configuration")
                    system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=5, placeholder="Enter system prompt here...")
                with gr.Group(elem_classes="parameter-section"):
                    gr.Markdown("### 🎛️ Generation Parameters")
                    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature (Creativity)", info="Higher values make output more random")
                    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (Nucleus Sampling)", info="Cumulative probability for token selection")
                    max_tokens = gr.Slider(minimum=256, maximum=16384, value=8192, step=256, label="Max Tokens", info="Maximum number of tokens to generate")
                gr.Markdown(f"""
                ### 📖 Instructions

                **Basic Usage:**
                - **Text Chat**: Enter your question and click Send
                - **Image Upload**: Drag & drop or select multiple images
                - **Mixed Input**: Upload images and enter text, then click Send
                - **Parameters**: Adjust generation settings as needed

                **Performance Info:**
                - Current Mode: **{'GPU Acceleration' if device == 'cuda' else 'CPU Mode'}**
                - Flash Attention: **{'Enabled' if (device == 'cuda' and HAS_FLASH_ATTN) else 'Disabled'}**
                - Recommended Image Size: < 1024×1024

                ### ⚠️ Disclaimer
                This demo is subject to the Gemma license agreement.
                Please do not generate or disseminate harmful content.
                """)
        submit_btn.click(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )
        empty_bin.click(reset_state, outputs=[chatbot, task_history, addfile_btn], show_progress="full")
        regen_btn.click(
            regenerate,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )
        query.submit(
            add_text_and_files,
            [chatbot, task_history, query, addfile_btn],
            [chatbot, task_history, query, addfile_btn]
        ).then(
            predict,
            [chatbot, task_history, system_prompt, temperature, top_p, max_tokens],
            [chatbot],
            show_progress="full"
        )

    demo.queue(max_size=10).launch(
        share=args.share, inbrowser=args.inbrowser,
        server_port=args.server_port, server_name=args.server_name,
        show_error=True
    )
def main():
    """Main entry point."""
    args = _get_args()
    model, processor, device = _load_model_processor(args)
    _launch_demo(args, model, processor, device)


if __name__ == '__main__':
    main()
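# Local usage sketch (the filename "app.py" is assumed, as is standard for Spaces;
# flags are those defined in _get_args above):
#   python app.py --checkpoint-path qiuxi337/IntrinSight-4B --server-port 7860
#   python app.py --cpu-only --share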