File size: 22,267 Bytes
eea83e8
d9616a5
eea83e8
 
3d2f97b
eea83e8
 
 
 
fa3cf85
3d2f97b
fa3cf85
 
 
 
 
 
3d2f97b
fa3cf85
 
 
 
 
 
 
 
eea83e8
fa3cf85
eea83e8
fa3cf85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eea83e8
fa3cf85
 
 
 
 
 
 
 
 
eea83e8
 
 
fa3cf85
eea83e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa3cf85
eea83e8
 
 
 
 
fa3cf85
eea83e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa3cf85
eea83e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa3cf85
eea83e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
import os
import spaces
import re
import random
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, TextIteratorStreamer
from threading import Thread
import subprocess

def install_cuda_toolkit():
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

install_cuda_toolkit()

def download_tools_ckpts(target_dir, url):
    from huggingface_hub import snapshot_download
    import os
    import shutil

    tmp_dir = "hf_temp_download"
    os.makedirs(tmp_dir, exist_ok=True)

    snapshot_download(
        repo_id="LYL1015/JarvisIR",
        repo_type="model",
        local_dir=tmp_dir,
        allow_patterns=os.path.join(url, "**"), 
        local_dir_use_symlinks=False,
    )

    src_dir = os.path.join(tmp_dir, url)
    
 
    shutil.copytree(src_dir, target_dir) 

    shutil.rmtree(tmp_dir)

target_dir = "JarvisIR/checkpoints/agent_tools"
if not os.path.exists(target_dir):
    download_tools_ckpts(target_dir, "agent_tools/checkpoints")

llm_targer_dir = "JarvisIR/checkpoints/pretrained_preview"
if not os.path.exists(llm_targer_dir):
    download_tools_ckpts(llm_targer_dir, "pretrained/preview")

# Model configuration
# XXX: Path to the fine-tuned LLaVA model
model_id = llm_targer_dir 
    
# Available image restoration tasks and their corresponding models
all_tasks = " {denoise: [scunet, restormer], lighten: [retinexformer_fivek, hvicidnet, lightdiff], \
                derain: [idt, turbo_rain, s2former], defog:[ridcp, kanet], \
                desnow:[turbo_snow, snowmaster], super_resolution: [real_esrgan], \
            }"

# Various prompt templates for querying the LLM about image degradation and restoration tasks
prompts_query2 = [
    f"Considering the image's degradation, suggest the required tasks with explanations, and identify suitable tools for each task. Options for tasks and tools include: {all_tasks}.",
    f"Given the image's degradation, outline the essential tasks along with justifications, and choose the appropriate tools for each task from the following options: {all_tasks}.",
    f"Please specify the tasks required due to the image's degradation, explain the reasons, and select relevant tools for each task from the provided options: {all_tasks}.",
    f"Based on the image degradation, determine the necessary tasks and their reasons, along with the appropriate tools for each task. Choose from these options: {all_tasks}.",
    f"Identify the tasks required to address the image's degradation, including the reasons for each, and select tools from the options: {all_tasks}.",
    f"Considering the degradation observed, list the tasks needed and their justifications, then pick the most suitable tools for each task from these options: {all_tasks}.",
    f"Evaluate the image degradation, and based on that, provide the necessary tasks and reasons, along with tools chosen from the options: {all_tasks}.",
    f"With respect to the image degradation, outline the tasks needed and explain why, selecting tools from the following list: {all_tasks}.",
    f"Given the level of degradation in the image, specify tasks to address it, include reasons, and select tools for each task from: {all_tasks}.",
    f"Examine the image's degradation, propose relevant tasks and their explanations, and identify tools from the options provided: {all_tasks}.",
    f"Based on observed degradation, detail the tasks required, explain your choices, and select tools from these options: {all_tasks}.",
    f"Using the image's degradation as a guide, list the necessary tasks, include explanations, and pick tools from the provided choices: {all_tasks}.",
    f"Assess the image degradation, provide the essential tasks and reasons, and select the appropriate tools for each task from the options: {all_tasks}.",
    f"According to the image's degradation, determine which tasks are necessary and why, choosing tools for each task from: {all_tasks}.",
    f"Observe the degradation in the image, specify the needed tasks with justifications, and select appropriate tools from: {all_tasks}.",
    f"Taking the image degradation into account, specify tasks needed, provide reasons, and choose tools from the following: {all_tasks}.",
    f"Consider the image's degradation level, outline the tasks necessary, provide reasoning, and select suitable tools from: {all_tasks}.",
    f"Evaluate the degradation in the image, identify tasks required, explain your choices, and pick tools from: {all_tasks}.",
    f"Analyze the image degradation and suggest tasks with justifications, choosing the best tools from these options: {all_tasks}.",
    f"Review the image degradation, and based on it, specify tasks needed, provide reasons, and select tools for each task from: {all_tasks}."
]

# Initialize models
print("Loading LLM model...")

# Initialize the image restoration toolkit
from agent_tools import RestorationToolkit
tool_engine = RestorationToolkit(score_weight=[0,0,0,0,0])
# Load the LLaVA model in half precision to reduce memory usage
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",  
    low_cpu_mem_usage=True
)
processor = AutoProcessor.from_pretrained(model_id)

print("Loading tool engine...")

def parse_llm_response(response):
    """
    Parse the LLM response to extract reason and answer sections
    
    Args:
        response (str): The raw response from the LLM
        
    Returns:
        tuple: (reason, answer) extracted from the response
    """
    reason_match = re.search(r'<reason>(.*?)</reason>', response, re.DOTALL)
    answer_match = re.search(r'<answer>(.*?)</answer>', response, re.DOTALL)
    
    reason = reason_match.group(1).strip() if reason_match else "No reasoning provided"
    answer = answer_match.group(1).strip() if answer_match else "No answer provided"
    
    return reason, answer

def extract_models_from_answer(answer):
    """
    Extract model names from the answer string using regex
    
    Args:
        answer (str): The answer string containing model recommendations
        
    Returns:
        list: List of extracted model names
    """
    # Pattern to match [type:xxx]:(model:xxx)
    pattern = r'\[type:[^\]]+\]:\(model:([^)]+)\)'
    models = re.findall(pattern, answer)
    return models

def beautify_recommended_actions(answer, models):
    """
    Format the LLM's recommendations in a more visually appealing way
    
    Args:
        answer (str): The raw answer from LLM
        models (list): List of extracted model names
        
    Returns:
        str: Beautified display of recommendations
    """
    
    # Task type to emoji mapping for visual enhancement
    task_icons = {
        'denoise': '🧹',
        'lighten': '💡', 
        'derain': '🌧️',
        'defog': '🌫️',
        'desnow': '❄️',
        'super_resolution': '🔍'
    }
    
    # Parse the answer to extract tasks and models
    pattern = r'\[type:([^\]]+)\]:\(model:([^)]+)\)'
    matches = re.findall(pattern, answer)
    
    if not matches:
        return f"**🎯 Recommended Actions:**\n\n{answer}\n\n**Extracted Models:** {', '.join(models) if models else 'None'}"
    
    # Create beautified display
    beautified = "**🎯 Recommended Actions:**\n"
    beautified += "> "
    
    # Create horizontal flow of actions
    action_parts = []
    for task_type, model_name in matches:
        task_type = task_type.strip()
        model_name = model_name.strip()
        
        # Get icon for task type
        icon = task_icons.get(task_type, '🔧')
        
        # Format task name (capitalize and replace underscores)
        task_display = task_type.title().replace('_', ' ')
        
        # Create action part: icon + task + model
        action_part = f"{icon} {task_display}:`{model_name}`"
        action_parts.append(action_part)
    
    # Join with arrows to show sequence
    beautified += " ➡ ".join(action_parts) + "\n\n"
    
    # Add summary information
    beautified += f"**📋 Processing Pipeline:** {len(matches)} steps\n"
    beautified += f"**🛠️ Models to use:** {' → '.join(models)}"
    
    return beautified

def resize_image_to_original(processed_image_path, original_size):
    """
    Resize processed image back to original dimensions
    
    Args:
        processed_image_path (str): Path to the processed image
        original_size (tuple): Original image dimensions (width, height)
        
    Returns:
        str: Path to the resized image
    """
    if processed_image_path and os.path.exists(processed_image_path):
        img = Image.open(processed_image_path)
        img_resized = img.resize(original_size, Image.Resampling.LANCZOS)
        
        # Save resized image
        output_path = os.path.join('temp_outputs', 'final_result.png')
        img_resized.save(output_path)
        return output_path
    return processed_image_path

def get_llm_response_streaming(image_path):
    """
    Get streaming response from LLM for image analysis
    
    Args:
        image_path (str): Path to the input image
        
    Returns:
        TextIteratorStreamer: A streamer object to yield tokens
    """
    # Select random prompt from the templates
    instruction = prompts_query2[random.randint(0, len(prompts_query2)-1)]
    
    # Format the prompt with image for multimodal input
    prompt = (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{instruction}<|eot_id|>"
              "<|start_header_id|>assistant<|end_header_id|>\n\n")
    
    # Load and process image
    raw_image = Image.open(image_path)
    inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
    
    # Setup streaming for token-by-token generation
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Generate response in a separate thread to avoid blocking
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=400,
        do_sample=False
    )
    
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    return streamer


def process_image_with_tools(image_path, models, original_size):
    """
    Process image using the tool engine and restore to original size
    
    Args:
        image_path (str): Path to the input image
        models (list): List of models to apply
        original_size (tuple): Original image dimensions
        
    Returns:
        str: Path to the final processed image
    """
    if not models:
        return None
    
    # Create output directory
    os.makedirs('temp_outputs', exist_ok=True)
    
    # Process the image with selected models
    res = tool_engine.process_image(models, image_path, 'temp_outputs')
    
    # Resize back to original dimensions
    final_result = resize_image_to_original(res['output_path'], original_size)
    
    return final_result
@spaces.GPU(duration=150)
def process_full_pipeline(image):
    """
    Main processing pipeline with streaming UI updates
    
    Args:
        image (str): Path to the input image
        
    Yields:
        tuple: (chat_history, processed_image) for Gradio UI updates
    """
    if image is None:
        return [], None
    
    try:
        # Get original image size for later restoration
        original_img = Image.open(image)
        original_size = original_img.size
        
        # Initialize chat history for UI
        chat_history = [("Image uploaded for analysis", None)]
        
        # Step 1: Get streaming LLM response
        streamer = get_llm_response_streaming(image)
        
        # Stream the response to UI with real-time updates
        full_response = ""
        in_reason = False
        in_answer = False
        reason_displayed = False
        answer_displayed = False
        reasoning_added = False  # Track if reasoning entry was added
        
        for new_text in streamer:
            full_response += new_text
            
            # Check if we're entering reason section or if we need to start showing content
            if ('<reason>' in full_response and not in_reason and not reason_displayed) or (not reasoning_added and not in_reason and not reason_displayed):
                in_reason = True
                reasoning_added = True
                
                if '<reason>' in full_response:
                    # Extract content after <reason>
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_content = full_response[reason_start:].strip()
                else:
                    # Show all content as reasoning if no tag yet
                    reason_content = full_response.strip()
                
                # Add reasoning to chat history
                chat_history.append((None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}"))
                yield chat_history, None
            
            # If we're in reason section, update content
            elif in_reason and not reason_displayed:
                # Check if reason section is complete
                if '</reason>' in full_response:
                    # Extract complete reason content
                    reason_start = full_response.find('<reason>') + len('<reason>')
                    reason_end = full_response.find('</reason>')
                    reason_content = full_response[reason_start:reason_end].strip()
                    
                    # Update chat history with complete reason
                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    reason_displayed = True
                    in_reason = False
                    yield chat_history, None
                else:
                    # Continue streaming reason content
                    if '<reason>' in full_response:
                        reason_start = full_response.find('<reason>') + len('<reason>')
                        reason_content = full_response[reason_start:].strip()
                    else:
                        reason_content = full_response.strip()
                    
                    # Update chat history with partial reason
                    chat_history[1] = (None, f"**🤔 Analysis & Reasoning:**\n\n{reason_content}")
                    yield chat_history, None
            
            # Check if we're entering answer section
            elif '<answer>' in full_response and not in_answer and not answer_displayed and reason_displayed:
                in_answer = True
                # Extract content after <answer>
                answer_start = full_response.find('<answer>') + len('<answer>')
                answer_content = full_response[answer_start:]
                
                # Add partial answer to chat history
                models = extract_models_from_answer(answer_content)
                beautified = beautify_recommended_actions(answer_content, models)
                chat_history.append((None, beautified))
                yield chat_history, None
            
            # If we're in answer section, update content
            elif in_answer and not answer_displayed:
                # Check if answer section is complete
                if '</answer>' in full_response:
                    # Extract complete answer content
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_end = full_response.find('</answer>')
                    answer_content = full_response[answer_start:answer_end].strip()
                    
                    # Parse and process final answer
                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    answer_displayed = True
                    in_answer = False
                    yield chat_history, None
                    
                    # Process image with tools
                    if models:
                        chat_history.append((None, "**🔄 Processing image...**"))
                        yield chat_history, None
                        
                        processed_image = process_image_with_tools(image, models, original_size)
                        chat_history[-1] = (None, "**✅ Processing Complete!**")
                        yield chat_history, processed_image
                        return
                    else:
                        chat_history.append((None, "**❌ No valid models found in the response**"))
                        yield chat_history, None
                        return
                else:
                    # Continue streaming answer content
                    answer_start = full_response.find('<answer>') + len('<answer>')
                    answer_content = full_response[answer_start:].strip()
                    
                    # Update chat history with partial answer
                    models = extract_models_from_answer(answer_content)
                    beautified = beautify_recommended_actions(answer_content, models)
                    chat_history[2] = (None, beautified)
                    yield chat_history, None
        
        # Fallback if streaming completes without proper tags
        if not answer_displayed:
            reason, answer = parse_llm_response(full_response)
            models = extract_models_from_answer(answer)
            
            chat_history = [
                ("Image uploaded for analysis", None),
                (None, f"**🤔 Analysis & Reasoning:**\n\n{reason}"),
                (None, beautify_recommended_actions(answer, models))
            ]
            
            if models:
                chat_history.append((None, "**🔄 Processing image...**"))
                yield chat_history, None
                
                processed_image = process_image_with_tools(image, models, original_size)
                chat_history[-1] = (None, "**✅ Processing Complete!**")
                yield chat_history, processed_image
            else:
                chat_history.append((None, "**❌ No valid models found in the response**"))
                yield chat_history, None
            
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        chat_history = [
            ("Image uploaded for analysis", None),
            (None, f"**❌ Error occurred:**\n\n{error_msg}")
        ]
        yield chat_history, None

# Create Gradio interface
def create_interface():
    """
    Create and configure the Gradio web interface
    
    Returns:
        gr.Blocks: Configured Gradio interface
    """
    with gr.Blocks(title="JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration", theme=gr.themes.Soft()) as demo:
        # Header with logo and title
        gr.Markdown("""
        # <img src="https://cvpr2025-jarvisir.github.io/imgs/icon.png" width="32" height="32" style="display: inline-block; vertical-align: middle; transform: translateY(-2px); margin-right: 1px;"/> JarvisIR: Elevating Autonomous Driving Perception with Intelligent Image Restoration
        
        Upload an image and let JarvisIR analyze its degradation and recommend the best restoration tools!
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Input image upload component
                input_image = gr.Image(
                    type="filepath", 
                    label="📸 Upload Your Image",
                    height=400
                )
                
                # Process button
                process_btn = gr.Button(
                    "🚀 Analyze & Process", 
                    variant="primary",
                    size="lg"
                )
                
            with gr.Column(scale=1):
                # Chat interface to show analysis
                chatbot = gr.Chatbot(
                    label="💬 AI Analysis Chat",
                    height=400,
                    show_label=True,
                    bubble_full_width=False
                )
        
        with gr.Row():
            # Output image display
            output_image = gr.Image(
                label="✨ Processed Result", 
                height=300
            )
        
        # Connect event handler for the process button
        process_btn.click(
            fn=process_full_pipeline,
            inputs=[input_image],
            outputs=[chatbot, output_image]
        )
        
        # Instructions section
        gr.Markdown("### 📝 Instructions:")
        gr.Markdown("""
        1. **Upload an image** that needs restoration (blurry, dark, noisy, etc.)
        2. **Click 'Analyze & Process'** to let AI analyze the image
        3. **View the chat** to see AI's reasoning and recommendations in real-time
        4. **Check the result** - processed image restored to original dimensions
        """)
    
    return demo

if __name__ == "__main__":
    print("Starting Image Restoration Assistant...")
    demo = create_interface()
    # Launch the Gradio app on specified host and port
    demo.launch(
        server_name="0.0.0.0",
        server_port=7866,
        share=False
    )