import os
import spaces
import re
import random
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from threading import Thread
from agent_tools import RestorationToolkit
# Download tool checkpoints from the Hub
import shutil
from huggingface_hub import hf_hub_download, list_repo_files

# Target directory for the restoration-tool checkpoints
target_dir = "JarvisIR/checkpoints/agent_tools"
os.makedirs(target_dir, exist_ok=True)

# Mirror the agent_tools/checkpoints subtree of the model repo locally
for repo_path in list_repo_files(repo_id="LYL1015/JarvisIR", repo_type="model"):
    if repo_path.startswith("agent_tools/checkpoints/"):
        rel_path = os.path.relpath(repo_path, "agent_tools/checkpoints")
        local_path = os.path.join(target_dir, rel_path)
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        # Download into the local HF cache, then copy to the expected location
        cached_file = hf_hub_download(repo_id="LYL1015/JarvisIR", filename=repo_path)
        shutil.copy(cached_file, local_path)
# Model configuration
# Hugging Face Hub repo id of the fine-tuned LLaVA model
model_id = "LYL1015/JarvisIR"
# Available image restoration tasks and their corresponding models
all_tasks = " {denoise: [scunet, restormer], lighten: [retinexformer_fivek, hvicidnet, lightdiff], \
derain: [idt, turbo_rain, s2former], defog:[ridcp, kanet], \
desnow:[turbo_snow, snowmaster], super_resolution: [real_esrgan], \
}"
# Various prompt templates for querying the LLM about image degradation and restoration tasks
prompts_query2 = [
f"Considering the image's degradation, suggest the required tasks with explanations, and identify suitable tools for each task. Options for tasks and tools include: {all_tasks}.",
f"Given the image's degradation, outline the essential tasks along with justifications, and choose the appropriate tools for each task from the following options: {all_tasks}.",
f"Please specify the tasks required due to the image's degradation, explain the reasons, and select relevant tools for each task from the provided options: {all_tasks}.",
f"Based on the image degradation, determine the necessary tasks and their reasons, along with the appropriate tools for each task. Choose from these options: {all_tasks}.",
f"Identify the tasks required to address the image's degradation, including the reasons for each, and select tools from the options: {all_tasks}.",
f"Considering the degradation observed, list the tasks needed and their justifications, then pick the most suitable tools for each task from these options: {all_tasks}.",
f"Evaluate the image degradation, and based on that, provide the necessary tasks and reasons, along with tools chosen from the options: {all_tasks}.",
f"With respect to the image degradation, outline the tasks needed and explain why, selecting tools from the following list: {all_tasks}.",
f"Given the level of degradation in the image, specify tasks to address it, include reasons, and select tools for each task from: {all_tasks}.",
f"Examine the image's degradation, propose relevant tasks and their explanations, and identify tools from the options provided: {all_tasks}.",
f"Based on observed degradation, detail the tasks required, explain your choices, and select tools from these options: {all_tasks}.",
f"Using the image's degradation as a guide, list the necessary tasks, include explanations, and pick tools from the provided choices: {all_tasks}.",
f"Assess the image degradation, provide the essential tasks and reasons, and select the appropriate tools for each task from the options: {all_tasks}.",
f"According to the image's degradation, determine which tasks are necessary and why, choosing tools for each task from: {all_tasks}.",
f"Observe the degradation in the image, specify the needed tasks with justifications, and select appropriate tools from: {all_tasks}.",
f"Taking the image degradation into account, specify tasks needed, provide reasons, and choose tools from the following: {all_tasks}.",
f"Consider the image's degradation level, outline the tasks necessary, provide reasoning, and select suitable tools from: {all_tasks}.",
f"Evaluate the degradation in the image, identify tasks required, explain your choices, and pick tools from: {all_tasks}.",
f"Analyze the image degradation and suggest tasks with justifications, choosing the best tools from these options: {all_tasks}.",
f"Review the image degradation, and based on it, specify tasks needed, provide reasons, and select tools for each task from: {all_tasks}."
]
# Initialize models
print("Loading tool engine...")
# Initialize the image restoration toolkit
tool_engine = RestorationToolkit(score_weight=[0, 0, 0, 0, 0])

print("Loading LLM model...")
# Load the LLaVA model in half precision to reduce memory usage
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    subfolder="pretrained/preview",  # key argument: load the weights from this subfolder
    torch_dtype=torch.float16,
    device_map="auto",  # smarter device placement than .to(0)
    low_cpu_mem_usage=True
)
processor = AutoProcessor.from_pretrained(model_id)
def parse_llm_response(response):
"""
Parse the LLM response to extract reason and answer sections
Args:
response (str): The raw response from the LLM
Returns:
tuple: (reason, answer) extracted from the response
"""
    reason_match = re.search(r'<reason>(.*?)</reason>', response, re.DOTALL)
    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
reason = reason_match.group(1).strip() if reason_match else "No reasoning provided"
answer = answer_match.group(1).strip() if answer_match else "No answer provided"
return reason, answer
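# Illustrative only -- a hypothetical well-formed LLM response that the parser
# above and the extractor below both handle:
#   <reason>Heavy noise and low light dominate the degradation.</reason>
#   <answer>[type:denoise]:(model:scunet) [type:lighten]:(model:hvicidnet)</answer>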
def extract_models_from_answer(answer):
"""
Extract model names from the answer string using regex
Args:
answer (str): The answer string containing model recommendations
Returns:
list: List of extracted model names
"""
# Pattern to match [type:xxx]:(model:xxx)
pattern = r'\[type:[^\]]+\]:\(model:([^)]+)\)'
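    # e.g. "[type:denoise]:(model:scunet) [type:derain]:(model:idt)" -> ['scunet', 'idt']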
models = re.findall(pattern, answer)
return models
def beautify_recommended_actions(answer, models):
"""
Format the LLM's recommendations in a more visually appealing way
Args:
answer (str): The raw answer from LLM
models (list): List of extracted model names
Returns:
str: Beautified display of recommendations
"""
# Task type to emoji mapping for visual enhancement
task_icons = {
'denoise': '🧹',
'lighten': '💡',
'derain': '🌧️',
'defog': '🌫️',
'desnow': '❄️',
'super_resolution': '🔍'
}
# Parse the answer to extract tasks and models
pattern = r'\[type:([^\]]+)\]:\(model:([^)]+)\)'
matches = re.findall(pattern, answer)
if not matches:
return f"**🎯 Recommended Actions:**\n\n{answer}\n\n**Extracted Models:** {', '.join(models) if models else 'None'}"
# Create beautified display
beautified = "**🎯 Recommended Actions:**\n"
beautified += "> "
# Create horizontal flow of actions
action_parts = []
for task_type, model_name in matches:
task_type = task_type.strip()
model_name = model_name.strip()
# Get icon for task type
icon = task_icons.get(task_type, '🔧')
# Format task name (capitalize and replace underscores)
task_display = task_type.title().replace('_', ' ')
# Create action part: icon + task + model
action_part = f"{icon} {task_display}:`{model_name}`"
action_parts.append(action_part)
# Join with arrows to show sequence
beautified += " ➡ ".join(action_parts) + "\n\n"
# Add summary information
beautified += f"**📋 Processing Pipeline:** {len(matches)} steps\n"
beautified += f"**🛠️ Models to use:** {' → '.join(models)}"
return beautified
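# Sketch of the rendered Markdown for a hypothetical two-step answer:
#   **🎯 Recommended Actions:**
#   > 🧹 Denoise:`scunet` ➡ 💡 Lighten:`hvicidnet`
#   **📋 Processing Pipeline:** 2 steps
#   **🛠️ Models to use:** scunet → hvicidnet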
def resize_image_to_original(processed_image_path, original_size):
"""
Resize processed image back to original dimensions
Args:
processed_image_path (str): Path to the processed image
original_size (tuple): Original image dimensions (width, height)
Returns:
str: Path to the resized image
"""
if processed_image_path and os.path.exists(processed_image_path):
img = Image.open(processed_image_path)
img_resized = img.resize(original_size, Image.Resampling.LANCZOS)
# Save resized image
output_path = os.path.join('temp_outputs', 'final_result.png')
img_resized.save(output_path)
return output_path
return processed_image_path
@spaces.GPU(duration=150)
def get_llm_response_streaming(image_path):
"""
Get streaming response from LLM for image analysis
Args:
image_path (str): Path to the input image
Returns:
TextIteratorStreamer: A streamer object to yield tokens
"""
    # Select a random prompt from the templates
    instruction = random.choice(prompts_query2)
    # Format the prompt with the <image> placeholder for multimodal input
    prompt = (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{instruction}<|eot_id|>"
              "<|start_header_id|>assistant<|end_header_id|>\n\n")
    # Load and process the image
    raw_image = Image.open(image_path)
    inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(model.device, torch.float16)
# Setup streaming for token-by-token generation
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Generate response in a separate thread to avoid blocking
generation_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=400,
do_sample=False
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
return streamer
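# Usage sketch (hypothetical path): the returned streamer is a plain iterator,
# so callers can consume tokens as they are generated, e.g.
#   for token in get_llm_response_streaming("road.png"): ...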
@spaces.GPU(duration=150)
def process_image_with_tools(image_path, models, original_size):
"""
Process image using the tool engine and restore to original size
Args:
image_path (str): Path to the input image
models (list): List of models to apply
original_size (tuple): Original image dimensions
Returns:
str: Path to the final processed image
"""
if not models:
return None
# Create output directory
os.makedirs('temp_outputs', exist_ok=True)
# Process the image with selected models
res = tool_engine.process_image(models, image_path, 'temp_outputs')
# Resize back to original dimensions
final_result = resize_image_to_original(res['output_path'], original_size)
return final_result
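# Usage sketch with hypothetical values:
#   process_image_with_tools("input.png", ["scunet", "hvicidnet"], (1920, 1080))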
def process_full_pipeline(image):
"""
Main processing pipeline with streaming UI updates
Args:
image (str): Path to the input image
Yields:
tuple: (chat_history, processed_image) for Gradio UI updates
"""
    if image is None:
        # This function is a generator, so yield (not return) the empty state
        yield [], None
        return
try:
# Get original image size for later restoration
original_img = Image.open(image)
original_size = original_img.size
# Initialize chat history for UI
chat_history = [("Image uploaded for analysis", None)]
# Step 1: Get streaming LLM response
streamer = get_llm_response_streaming(image)
# Stream the response to UI with real-time updates
full_response = ""
in_reason = False
in_answer = False
reason_displayed = False
answer_displayed = False
reasoning_added = False # Track if reasoning entry was added
for new_text in streamer:
full_response += new_text
            # Check if we're entering the reason section, or need to start showing content
            if ('<reason>' in full_response and not in_reason and not reason_displayed) or (not reasoning_added and not in_reason and not reason_displayed):
                in_reason = True
                reasoning_added = True
                if '<reason>' in full_response:
                    # Extract content after <reason>
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_content = full_response[reason_start:].strip()
                else:
                    # Show all content as reasoning if no tag has appeared yet
                    reason_content = full_response.strip()
# Add reasoning to chat history
chat_history.append((None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}"))
yield chat_history, None
# If we're in reason section, update content
elif in_reason and not reason_displayed:
                # Check if the reason section is complete
                if '</reason>' in full_response:
                    # Extract complete reason content
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_end = full_response.find('</reason>')
                    reason_content = full_response[reason_start:reason_end].strip()
# Update chat history with complete reason
chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
reason_displayed = True
in_reason = False
yield chat_history, None
else:
# Continue streaming reason content
                    if '<reason>' in full_response:
                        reason_start = full_response.find('<reason>') + len('<reason>')
reason_content = full_response[reason_start:].strip()
else:
reason_content = full_response.strip()
# Update chat history with partial reason
chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
yield chat_history, None
            # Check if we're entering the answer section
            elif '<answer>' in full_response and not in_answer and not answer_displayed and reason_displayed:
                in_answer = True
                # Extract content after <answer>
                answer_start = full_response.find('<answer>') + len('<answer>')
                answer_content = full_response[answer_start:]
# Add partial answer to chat history
models = extract_models_from_answer(answer_content)
beautified = beautify_recommended_actions(answer_content, models)
chat_history.append((None, beautified))
yield chat_history, None
# If we're in answer section, update content
elif in_answer and not answer_displayed:
                # Check if the answer section is complete
                if '</answer>' in full_response:
                    # Extract complete answer content
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_end = full_response.find('</answer>')
                    answer_content = full_response[answer_start:answer_end].strip()
# Parse and process final answer
models = extract_models_from_answer(answer_content)
beautified = beautify_recommended_actions(answer_content, models)
chat_history[2] = (None, beautified)
answer_displayed = True
in_answer = False
yield chat_history, None
# Process image with tools
if models:
chat_history.append((None, "**🔄 Processing image...**"))
yield chat_history, None
processed_image = process_image_with_tools(image, models, original_size)
chat_history[-1] = (None, "**✅ Processing Complete!**")
yield chat_history, processed_image
return
else:
chat_history.append((None, "**❌ No valid models found in the response**"))
yield chat_history, None
return
else:
# Continue streaming answer content
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_content = full_response[answer_start:].strip()
# Update chat history with partial answer
models = extract_models_from_answer(answer_content)
beautified = beautify_recommended_actions(answer_content, models)
chat_history[2] = (None, beautified)
yield chat_history, None
# Fallback if streaming completes without proper tags
if not answer_displayed:
reason, answer = parse_llm_response(full_response)
models = extract_models_from_answer(answer)
chat_history = [
("Image uploaded for analysis", None),
(None, f"**🤔 Analysis & Reasoning:**\n\n{reason}"),
(None, beautify_recommended_actions(answer, models))
]
if models:
chat_history.append((None, "**🔄 Processing image...**"))
yield chat_history, None
processed_image = process_image_with_tools(image, models, original_size)
chat_history[-1] = (None, "**✅ Processing Complete!**")
yield chat_history, processed_image
else:
chat_history.append((None, "**❌ No valid models found in the response**"))
yield chat_history, None
except Exception as e:
error_msg = f"Error: {str(e)}"
chat_history = [
("Image uploaded for analysis", None),
(None, f"**❌ Error occurred:**\n\n{error_msg}")
]
yield chat_history, None
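# NOTE: Gradio runs generator handlers as streaming functions, so every
# `yield chat_history, image` above pushes an incremental update to the UI.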
# Create Gradio interface
def create_interface():
"""
Create and configure the Gradio web interface
Returns:
gr.Blocks: Configured Gradio interface
"""
with gr.Blocks(title="JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration", theme=gr.themes.Soft()) as demo:
# Header with logo and title
gr.Markdown("""
#
JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration
Upload an image and let JarvisIR analyze its degradation and recommend the best restoration tools!
""")
with gr.Row():
with gr.Column(scale=1):
# Input image upload component
input_image = gr.Image(
type="filepath",
label="📸 Upload Your Image",
height=400
)
# Process button
process_btn = gr.Button(
"🚀 Analyze & Process",
variant="primary",
size="lg"
)
with gr.Column(scale=1):
# Chat interface to show analysis
chatbot = gr.Chatbot(
label="💬 AI Analysis Chat",
height=400,
show_label=True,
bubble_full_width=False
)
with gr.Row():
# Output image display
output_image = gr.Image(
label="✨ Processed Result",
height=300
)
# Connect event handler for the process button
process_btn.click(
fn=process_full_pipeline,
inputs=[input_image],
outputs=[chatbot, output_image]
)
# Instructions section
gr.Markdown("### 📝 Instructions:")
gr.Markdown("""
1. **Upload an image** that needs restoration (blurry, dark, noisy, etc.)
2. **Click 'Analyze & Process'** to let AI analyze the image
3. **View the chat** to see AI's reasoning and recommendations in real-time
4. **Check the result** - processed image restored to original dimensions
""")
return demo
if __name__ == "__main__":
print("Starting Image Restoration Assistant...")
demo = create_interface()
# Launch the Gradio app on specified host and port
demo.launch(
server_name="0.0.0.0",
server_port=7866,
share=False
)