obichimav committed on
Commit 2541bac (verified) · 1 Parent(s): 2bd484d

Update app.py

Files changed (1): app.py +159 -280
app.py CHANGED
@@ -328,300 +328,179 @@

  # demo.launch(share=True)

- # imports
  import os
- import json
- import base64
- from io import BytesIO
- from dotenv import load_dotenv
- from openai import OpenAI
  import gradio as gr
- import numpy as np
- from PIL import Image, ImageDraw
- import requests
- import matplotlib.pyplot as plt
  from vision_agent.agent import VisionAgentCoderV2
  from vision_agent.models import AgentMessage
- import vision_agent.tools as T
-
- # Initialization
- load_dotenv()
- os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
- os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
- PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
- MODEL = "gpt-4o"
- openai = OpenAI()
-
- # Initialize VisionAgent (kept for potential future use, though not used directly in detection below)
- agent = VisionAgentCoderV2(verbose=False)
-
- system_message = """You are an expert in object detection. When users mention:
- 1. "count [object(s)]" - Use detect_objects to count them
- 2. "detect [object(s)]" - Same as count
- 3. "show [object(s)]" - Same as count
-
- Always use object detection tool when counting/detecting is mentioned.
- Always be accurate. If you don't know the answer, say so."""
-
- class State:
-     def __init__(self):
-         self.current_image = None
-         self.last_prediction = None
-
- state = State()
-
- def encode_image_to_base64(image_array):
-     if image_array is None:
-         return None
-     image = Image.fromarray(image_array)
-     buffered = BytesIO()
-     image.save(buffered, format="JPEG")
-     return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
- def save_temp_image(image_array):
-     """Save the image to a temporary file for VisionAgent to process"""
-     temp_path = "temp_image.jpg"
-     image = Image.fromarray(image_array)
-     image.save(temp_path)
-     return temp_path
-
- def detect_objects(query_text):
-     if state.current_image is None:
-         return {"count": 0, "message": "No image provided"}
-
-     # Save the current image to a temporary file
-     image_path = save_temp_image(state.current_image)
-
-     try:
-         # Clean query text to get the object name
-         object_name = query_text[0].replace("a photo of ", "").strip()
-
-         # Load the image for detection and visualization
-         image = T.load_image(image_path)
-
-         # Use the specialized detector first with a threshold of 0.55
-         detections = T.countgd_object_detection(object_name, image, conf_threshold=0.55)
-         if detections is None:
-             detections = []
-
-         # If no detections, try the more general grounding_dino detector
-         if not detections:
-             try:
-                 detections = T.grounding_dino_detection(object_name, image, box_threshold=0.55)
-                 if detections is None:
-                     detections = []
-             except Exception as e:
-                 print(f"Error in grounding_dino_detection: {str(e)}")
-                 detections = []
-
-         # Only keep detections with confidence higher than 0.55
-         high_conf_detections = [det for det in detections if det.get("score", 0) >= 0.55]
-
-         # Visualize the high confidence detections with clear labeling
-         result_image = T.overlay_bounding_boxes(
-             image,
-             high_conf_detections,
-             labels=[f"{object_name}: {det['score']:.2f}" for det in high_conf_detections]
-         )
-
-         # Convert result back to numpy array for display
-         state.last_prediction = np.array(result_image)
-
-         return {
-             "count": len(high_conf_detections),
-             "confidence": [det["score"] for det in high_conf_detections],
-             "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>=0.55)"
-         }
-     except Exception as e:
-         print(f"Error in detect_objects: {str(e)}")
-         return {"count": 0, "message": f"Error: {str(e)}"}
-
- def identify_plant():
-     if state.current_image is None:
-         return {"error": "No image provided"}
-
-     image = Image.fromarray(state.current_image)
-     img_byte_arr = BytesIO()
-     image.save(img_byte_arr, format='JPEG')
-     img_byte_arr = img_byte_arr.getvalue()
-
-     api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
-     files = [('images', ('image.jpg', img_byte_arr))]
-     data = {'organs': ['leaf']}
-
-     try:
-         response = requests.post(api_endpoint, files=files, data=data)
-         if response.status_code == 200:
-             result = response.json()
-             best_match = result['results'][0]
-             return {
-                 "scientific_name": best_match['species']['scientificName'],
-                 "common_names": best_match['species'].get('commonNames', []),
-                 "family": best_match['species']['family']['scientificName'],
-                 "genus": best_match['species']['genus']['scientificName'],
-                 "confidence": f"{best_match['score']*100:.1f}%"
-             }
-         else:
-             return {"error": f"API Error: {response.status_code}"}
-     except Exception as e:
-         return {"error": f"Error: {str(e)}"}
-
- # Tool definitions
- object_detection_function = {
-     "name": "detect_objects",
-     "description": "Use this function to detect and count objects in images based on text queries.",
-     "parameters": {
-         "type": "object",
-         "properties": {
-             "query_text": {
-                 "type": "array",
-                 "description": "List of text queries describing objects to detect",
-                 "items": {"type": "string"}
-             }
-         }
-     }
- }
-
- plant_identification_function = {
-     "name": "identify_plant",
-     "description": "Use this when asked about plant species identification or botanical classification.",
-     "parameters": {
-         "type": "object",
-         "properties": {},
-         "required": []
-     }
- }
-
- tools = [
-     {"type": "function", "function": object_detection_function},
-     {"type": "function", "function": plant_identification_function}
- ]
-
- def format_tool_response(tool_response_content):
-     data = json.loads(tool_response_content)
-     if "error" in data:
-         return f"Error: {data['error']}"
-     elif "scientific_name" in data:
-         return f"""📋 Plant Identification Results:
-
- 🌿 Scientific Name: {data['scientific_name']}
- 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
- 👪 Family: {data['family']}
- 🎯 Confidence: {data['confidence']}"""
-     else:
-         return f"I detected {data['count']} objects in the image."

- def chat(message, image, history):
-     if image is not None:
-         state.current_image = image
-
-     if state.current_image is None:
-         return "Please upload an image first.", None

-     base64_image = encode_image_to_base64(state.current_image)
-     messages = [{"role": "system", "content": system_message}]

-     for human, assistant in history:
-         messages.append({"role": "user", "content": human})
-         messages.append({"role": "assistant", "content": assistant})

-     # Extract objects to detect from user message
-     objects_to_detect = message.lower()
-     cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
-     query = ["a photo of " + cleaned_query]

-     messages.append({
-         "role": "user",
-         "content": [
-             {"type": "text", "text": message},
-             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
-         ]
-     })

-     response = openai.chat.completions.create(
-         model=MODEL,
-         messages=messages,
-         tools=tools,
-         max_tokens=300
-     )

-     # Check if a tool call is required based on the response
-     if response.choices[0].finish_reason == "tool_calls":
-         message_obj = response.choices[0].message
-         messages.append(message_obj)
-
-         # Process each tool call from the message
-         for tool_call in message_obj.tool_calls:
-             if tool_call.function.name == "detect_objects":
-                 results = detect_objects(query)
-             else:
-                 results = identify_plant()
-
-             tool_response = {
-                 "role": "tool",
-                 "content": json.dumps(results),
-                 "tool_call_id": tool_call.id
-             }
-             messages.append(tool_response)
-
-         response = openai.chat.completions.create(
-             model=MODEL,
-             messages=messages,
-             max_tokens=300
-         )
-
-     return response.choices[0].message.content, state.last_prediction
-
- # Create Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("# Object Detection and Plant Analysis System using VisionAgent")

      with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(type="numpy", label="Upload Image")
-             text_input = gr.Textbox(
-                 label="Ask about the image",
-                 placeholder="e.g., 'Count dogs in this image' or 'What species is this plant?'"
-             )
-             with gr.Row():
-                 submit_btn = gr.Button("Analyze")
-                 reset_btn = gr.Button("Reset")
-
-         with gr.Column():
-             chatbot = gr.Chatbot()
-             output_image = gr.Image(type="numpy", label="Detection Results")
-
-     def process_interaction(message, image, history):
-         response_text, pred_image = chat(message, image, history)
-         history.append((message, response_text))
-         return "", pred_image, history
-
-     def reset_interface():
-         state.current_image = None
-         state.last_prediction = None
-         return None, None, None, []
-
-     submit_btn.click(
-         fn=process_interaction,
-         inputs=[text_input, image_input, chatbot],
-         outputs=[text_input, output_image, chatbot]
-     )

-     reset_btn.click(
-         fn=reset_interface,
-         inputs=[],
-         outputs=[image_input, output_image, text_input, chatbot]
-     )
-
-     gr.Markdown("""## Instructions
- 1. Upload an image
- 2. Ask specific questions about objects or plants
- 3. Click Analyze to get results

- Examples:
- - "Count the number of people in this image"
- - "Detect cats and dogs"
- - "What species is this plant?"
- """)

- demo.launch(share=True)

  # demo.launch(share=True)

  import os
+ import re
+ import io
+ import uuid
+ import contextlib
  import gradio as gr
+ from PIL import Image
+ import shutil
+
+ # Required packages:
+ # pip install vision-agent gradio openai anthropic
+
  from vision_agent.agent import VisionAgentCoderV2
  from vision_agent.models import AgentMessage

+ #############################################
+ # GLOBAL INITIALIZATION
+ #############################################
+
+ # Create a unique temporary directory for saved images
+ TEMP_DIR = "temp_images"
+ if not os.path.exists(TEMP_DIR):
+     os.makedirs(TEMP_DIR)
+
+ # Initialize VisionAgentCoderV2 with verbose logging so the generated code has detailed print outputs.
+ agent = VisionAgentCoderV2(verbose=True)
+
+ #############################################
+ # UTILITY: SAVE UPLOADED IMAGE TO A TEMP FILE
+ #############################################
+
+ def save_uploaded_image(image):
+     """
+     Saves the uploaded image (a numpy array) to a temporary file.
+     Returns the filename (including path) to be passed as media to VisionAgent.
+     """
+     # Generate a unique filename
+     filename = os.path.join(TEMP_DIR, f"{uuid.uuid4().hex}.jpg")
+     im = Image.fromarray(image)
+     im.save(filename)
+     return filename
+
+ #############################################
+ # UTILITY: PARSE FILENAMES FROM save_image(...)
+ #############################################
+
+ def parse_saved_image_filenames(code_str):
+     """
+     Find all filenames in lines that look like:
+         save_image(..., 'filename.jpg')
+     Returns a list of the extracted filenames.
+     """
+     pattern = r"save_image\s*\(\s*[^,]+,\s*'([^']+)'\s*\)"
+     return re.findall(pattern, code_str)
+
+ #############################################
+ # UTILITY: EXECUTE CODE, CAPTURE STDOUT, IDENTIFY IMAGES
+ #############################################
+
+ def run_and_capture_with_images(code_str):
+     """
+     Executes the given code_str, capturing stdout and returning:
+       - output: a string with all print statements (the step logs)
+       - existing_images: list of filenames that were saved and exist on disk.
+     """
+     # Parse the code for image filenames saved via save_image
+     filenames = parse_saved_image_filenames(code_str)
+
+     # Capture stdout using a StringIO buffer
+     buf = io.StringIO()
+     with contextlib.redirect_stdout(buf):
+         # IMPORTANT: Here we exec the generated code.
+         exec(code_str, globals(), locals())
+
+     # Gather all printed output
+     output = buf.getvalue()
+
+     # Check which of the parsed filenames exist on disk (prepend TEMP_DIR if needed)
+     existing_images = []
+     for fn in filenames:
+         # If filename is not an absolute path, assume it is in TEMP_DIR
+         if not os.path.isabs(fn):
+             fn = os.path.join(TEMP_DIR, fn)
+         if os.path.exists(fn):
+             existing_images.append(fn)
+     return output, existing_images
+
+ #############################################
+ # CHAT FUNCTION: PROCESS USER PROMPT & IMAGE
+ #############################################
+
+ def chat(prompt, image, history):
+     """
+     When the user sends a prompt and optionally an image, do the following:
+       1. Save the image to a temp file.
+       2. Use VisionAgentCoderV2 to generate code for the task.
+       3. Execute the generated code, capturing its stdout logs and any saved image files.
+       4. Append the logs and image gallery info to the conversation history.
+     """
+     # Validate that an image was provided.
+     if image is None:
+         history.append(("System", "Please upload an image."))
+         return history, None
+
+     # Save the uploaded image for use in the generated code.
+     image_path = save_uploaded_image(image)
+
+     # Generate the code with VisionAgent using the user prompt and the image filename.
+     code_context = agent.generate_code(
+         [
+             AgentMessage(
+                 role="user",
+                 content=prompt,
+                 media=[image_path]
+             )
+         ]
+     )

+     # Combine the generated code and its test snippet.
+     generated_code = code_context.code + "\n" + code_context.test
+
+     # Run the generated code and capture output and any saved images.
+     stdout_text, image_files = run_and_capture_with_images(generated_code)

+     # Format the response text (the captured logs).
+     response_text = f"**Execution Logs:**\n{stdout_text}\n"
+     if image_files:
+         response_text += "\n**Saved Images:** " + ", ".join(image_files)
+     else:
+         response_text += "\nNo images were saved by the generated code."

+     # Append the prompt and response to the chat history.
+     history.append((prompt, response_text))

+     # Optionally, you could clear the image input after use.
+     return history, image_files

+ #############################################
+ # GRADIO CHAT INTERFACE
+ #############################################

  with gr.Blocks() as demo:
+     gr.Markdown("# VisionAgent Chat App")
+     gr.Markdown(
+         """
+         This chat app lets you enter a prompt (e.g., "Count the number of cacao oranges in the image")
+         along with an image. The app then uses VisionAgentCoderV2 to generate multi-step code, executes it,
+         and returns the detailed logs and any saved images.
+         """
+     )

      with gr.Row():
+         with gr.Column(scale=7):
+             chatbot = gr.Chatbot(label="Chat History")
+             prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., Count the number of cacao oranges in the image")
+             submit_btn = gr.Button("Send")
+         with gr.Column(scale=5):
+             image_input = gr.Image(label="Upload Image", type="numpy")

+     gallery = gr.Gallery(label="Generated Images").style(grid=[2], height="auto")
+
+     # Clear chat history button
+     clear_btn = gr.Button("Clear Chat")
+
+     # Chat function wrapper (it takes current chat history, prompt, image)
+     def user_chat_wrapper(prompt, image, history):
+         history = history or []
+         history, image_files = chat(prompt, image, history)
+         return history, image_files

+     submit_btn.click(fn=user_chat_wrapper, inputs=[prompt_input, image_input, chatbot], outputs=[chatbot, gallery])
+
+     clear_btn.click(lambda: ([], None), None, [chatbot, gallery])
+
+ demo.launch()
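
As a quick sanity check of the filename-parsing helper added in this revision, the snippet below shows what `parse_saved_image_filenames` would extract from a line of generated code. The input string is made up for illustration; note the pattern only matches single-quoted, positional filename arguments to `save_image(...)`.

```python
import re

# Same pattern as parse_saved_image_filenames in the new app.py.
pattern = r"save_image\s*\(\s*[^,]+,\s*'([^']+)'\s*\)"

# Hypothetical fragment of VisionAgent-generated code (illustration only).
generated = """
result_img = overlay_bounding_boxes(image, detections)
save_image(result_img, 'people_detected.jpg')
"""

print(re.findall(pattern, generated))  # -> ['people_detected.jpg']
```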
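And a minimal sketch of driving the new flow end to end without the Gradio UI. It assumes the helpers defined above (`chat`, `save_uploaded_image`, `run_and_capture_with_images`) are already in scope (importing app.py directly would also launch the demo) and that the API keys VisionAgentCoderV2 needs are configured in the environment; the image path and prompt are placeholders.

```python
import numpy as np
from PIL import Image

# Placeholder inputs for illustration only.
image = np.array(Image.open("example.jpg"))
history = []

# chat() saves the image to temp_images/, asks VisionAgentCoderV2 to generate code
# for the prompt, executes that code while capturing stdout, and returns the updated
# history plus the paths of any saved images it could locate on disk.
history, saved_images = chat("Count the number of people in the image", image, history)

print(history[-1][1])  # execution logs formatted by chat()
print(saved_images)    # image paths found by run_and_capture_with_images()
```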