Update app.py
app.py
CHANGED
@@ -328,300 +328,179 @@
 
 # demo.launch(share=True)
 
-# imports
 import os
-import
-import
-
-
-from openai import OpenAI
 import gradio as gr
-
-
-
-
 from vision_agent.agent import VisionAgentCoderV2
 from vision_agent.models import AgentMessage
-import vision_agent.tools as T
-
-# Initialization
-load_dotenv()
-os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
-os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
-PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
-MODEL = "gpt-4o"
-openai = OpenAI()
-
-# Initialize VisionAgent (kept for potential future use, though not used directly in detection below)
-agent = VisionAgentCoderV2(verbose=False)
-
-system_message = """You are an expert in object detection. When users mention:
-1. "count [object(s)]" - Use detect_objects to count them
-2. "detect [object(s)]" - Same as count
-3. "show [object(s)]" - Same as count
-
-Always use object detection tool when counting/detecting is mentioned.
-Always be accurate. If you don't know the answer, say so."""
-
-class State:
-    def __init__(self):
-        self.current_image = None
-        self.last_prediction = None
-
-state = State()
-
-def encode_image_to_base64(image_array):
-    if image_array is None:
-        return None
-    image = Image.fromarray(image_array)
-    buffered = BytesIO()
-    image.save(buffered, format="JPEG")
-    return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-def save_temp_image(image_array):
-    """Save the image to a temporary file for VisionAgent to process"""
-    temp_path = "temp_image.jpg"
-    image = Image.fromarray(image_array)
-    image.save(temp_path)
-    return temp_path
-
-def detect_objects(query_text):
-    if state.current_image is None:
-        return {"count": 0, "message": "No image provided"}
-
-    # Save the current image to a temporary file
-    image_path = save_temp_image(state.current_image)
-
-    try:
-        # Clean query text to get the object name
-        object_name = query_text[0].replace("a photo of ", "").strip()
-
-        # Load the image for detection and visualization
-        image = T.load_image(image_path)
-
-        # Use the specialized detector first with a threshold of 0.55
-        detections = T.countgd_object_detection(object_name, image, conf_threshold=0.55)
-        if detections is None:
-            detections = []
-
-        # If no detections, try the more general grounding_dino detector
-        if not detections:
-            try:
-                detections = T.grounding_dino_detection(object_name, image, box_threshold=0.55)
-                if detections is None:
-                    detections = []
-            except Exception as e:
-                print(f"Error in grounding_dino_detection: {str(e)}")
-                detections = []
-
-        # Only keep detections with confidence higher than 0.55
-        high_conf_detections = [det for det in detections if det.get("score", 0) >= 0.55]
-
-        # Visualize the high confidence detections with clear labeling
-        result_image = T.overlay_bounding_boxes(
-            image,
-            high_conf_detections,
-            labels=[f"{object_name}: {det['score']:.2f}" for det in high_conf_detections]
-        )
-
-        # Convert result back to numpy array for display
-        state.last_prediction = np.array(result_image)
-
-        return {
-            "count": len(high_conf_detections),
-            "confidence": [det["score"] for det in high_conf_detections],
-            "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>=0.55)"
-        }
-    except Exception as e:
-        print(f"Error in detect_objects: {str(e)}")
-        return {"count": 0, "message": f"Error: {str(e)}"}
-
-def identify_plant():
-    if state.current_image is None:
-        return {"error": "No image provided"}
-
-    image = Image.fromarray(state.current_image)
-    img_byte_arr = BytesIO()
-    image.save(img_byte_arr, format='JPEG')
-    img_byte_arr = img_byte_arr.getvalue()
-
-    api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
-    files = [('images', ('image.jpg', img_byte_arr))]
-    data = {'organs': ['leaf']}
-
-    try:
-        response = requests.post(api_endpoint, files=files, data=data)
-        if response.status_code == 200:
-            result = response.json()
-            best_match = result['results'][0]
-            return {
-                "scientific_name": best_match['species']['scientificName'],
-                "common_names": best_match['species'].get('commonNames', []),
-                "family": best_match['species']['family']['scientificName'],
-                "genus": best_match['species']['genus']['scientificName'],
-                "confidence": f"{best_match['score']*100:.1f}%"
-            }
-        else:
-            return {"error": f"API Error: {response.status_code}"}
-    except Exception as e:
-        return {"error": f"Error: {str(e)}"}
-
-# Tool definitions
-object_detection_function = {
-    "name": "detect_objects",
-    "description": "Use this function to detect and count objects in images based on text queries.",
-    "parameters": {
-        "type": "object",
-        "properties": {
-            "query_text": {
-                "type": "array",
-                "description": "List of text queries describing objects to detect",
-                "items": {"type": "string"}
-            }
-        }
-    }
-}
-
-plant_identification_function = {
-    "name": "identify_plant",
-    "description": "Use this when asked about plant species identification or botanical classification.",
-    "parameters": {
-        "type": "object",
-        "properties": {},
-        "required": []
-    }
-}
-
-tools = [
-    {"type": "function", "function": object_detection_function},
-    {"type": "function", "function": plant_identification_function}
-]
-
-def format_tool_response(tool_response_content):
-    data = json.loads(tool_response_content)
-    if "error" in data:
-        return f"Error: {data['error']}"
-    elif "scientific_name" in data:
-        return f"""📋 Plant Identification Results:
-
-🌿 Scientific Name: {data['scientific_name']}
-👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
-👪 Family: {data['family']}
-🎯 Confidence: {data['confidence']}"""
-    else:
-        return f"I detected {data['count']} objects in the image."
 
-
-
-
-    #
-
-    cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
-    query = ["a photo of " + cleaned_query]
 
-
-        "content": [
-            {"type": "text", "text": message},
-            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
-        ]
-    })
 
-
-        tools=tools,
-        max_tokens=300
-    )
 
-    # Check if a tool call is required based on the response
-    if response.choices[0].finish_reason == "tool_calls":
-        message_obj = response.choices[0].message
-        messages.append(message_obj)
-
-        # Process each tool call from the message
-        for tool_call in message_obj.tool_calls:
-            if tool_call.function.name == "detect_objects":
-                results = detect_objects(query)
-            else:
-                results = identify_plant()
-
-            tool_response = {
-                "role": "tool",
-                "content": json.dumps(results),
-                "tool_call_id": tool_call.id
-            }
-            messages.append(tool_response)
-
-        response = openai.chat.completions.create(
-            model=MODEL,
-            messages=messages,
-            max_tokens=300
-        )
-
-    return response.choices[0].message.content, state.last_prediction
-
-# Create Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("#
 
     with gr.Row():
-        with gr.Column():
-
-
-
-            )
-            with gr.Row():
-                submit_btn = gr.Button("Analyze")
-                reset_btn = gr.Button("Reset")
-
-        with gr.Column():
-            chatbot = gr.Chatbot()
-            output_image = gr.Image(type="numpy", label="Detection Results")
-
-    def process_interaction(message, image, history):
-        response_text, pred_image = chat(message, image, history)
-        history.append((message, response_text))
-        return "", pred_image, history
-
-    def reset_interface():
-        state.current_image = None
-        state.last_prediction = None
-        return None, None, None, []
-
-    submit_btn.click(
-        fn=process_interaction,
-        inputs=[text_input, image_input, chatbot],
-        outputs=[text_input, output_image, chatbot]
-    )
-
-demo.launch(share=True)
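Parts of the removed chat() body are blank in the hunk above, so only fragments of the old GPT-4o flow are visible: the query cleanup, the image-URL message content, the tools=tools completion call, and the finish_reason == "tool_calls" branch. A minimal sketch of the tool-calling round trip those fragments imply is shown below; run_tool_call_round is a hypothetical name, and it assumes the tools list, detect_objects, and identify_plant defined above. This is a reconstruction, not the original code.

# Hypothetical sketch of the deleted GPT-4o tool-calling round trip; the exact
# original wiring is not fully visible in this diff.
import json
from openai import OpenAI

openai = OpenAI()
MODEL = "gpt-4o"

def run_tool_call_round(messages, tools, query, detect_objects, identify_plant):
    # First call: let the model decide whether one of the two tools is needed.
    response = openai.chat.completions.create(
        model=MODEL, messages=messages, tools=tools, max_tokens=300
    )
    choice = response.choices[0]
    if choice.finish_reason != "tool_calls":
        return choice.message.content

    # Echo the assistant message carrying the tool calls, then answer each call.
    messages.append(choice.message)
    for tool_call in choice.message.tool_calls:
        if tool_call.function.name == "detect_objects":
            results = detect_objects(query)
        else:
            results = identify_plant()
        messages.append({
            "role": "tool",
            "content": json.dumps(results),
            "tool_call_id": tool_call.id,
        })

    # Second call: the model turns the raw tool output into a user-facing reply.
    final = openai.chat.completions.create(model=MODEL, messages=messages, max_tokens=300)
    return final.choices[0].message.content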
 
 # demo.launch(share=True)
 
 import os
+import re
+import io
+import uuid
+import contextlib
 import gradio as gr
+from PIL import Image
+import shutil
+
+# Required packages:
+# pip install vision-agent gradio openai anthropic
+
 from vision_agent.agent import VisionAgentCoderV2
 from vision_agent.models import AgentMessage
 
+#############################################
+# GLOBAL INITIALIZATION
+#############################################
+
+# Create a unique temporary directory for saved images
+TEMP_DIR = "temp_images"
+if not os.path.exists(TEMP_DIR):
+    os.makedirs(TEMP_DIR)
+
+# Initialize VisionAgentCoderV2 with verbose logging so the generated code has detailed print outputs.
+agent = VisionAgentCoderV2(verbose=True)
+
+#############################################
+# UTILITY: SAVE UPLOADED IMAGE TO A TEMP FILE
+#############################################
+
+def save_uploaded_image(image):
+    """
+    Saves the uploaded image (a numpy array) to a temporary file.
+    Returns the filename (including path) to be passed as media to VisionAgent.
+    """
+    # Generate a unique filename
+    filename = os.path.join(TEMP_DIR, f"{uuid.uuid4().hex}.jpg")
+    im = Image.fromarray(image)
+    im.save(filename)
+    return filename
+
+#############################################
+# UTILITY: PARSE FILENAMES FROM save_image(...)
+#############################################
+
+def parse_saved_image_filenames(code_str):
+    """
+    Find all filenames in lines that look like:
+        save_image(..., 'filename.jpg')
+    Returns a list of the extracted filenames.
+    """
+    pattern = r"save_image\s*\(\s*[^,]+,\s*'([^']+)'\s*\)"
+    return re.findall(pattern, code_str)
+
+#############################################
+# UTILITY: EXECUTE CODE, CAPTURE STDOUT, IDENTIFY IMAGES
+#############################################
+
+def run_and_capture_with_images(code_str):
+    """
+    Executes the given code_str, capturing stdout and returning:
+      - output: a string with all print statements (the step logs)
+      - existing_images: list of filenames that were saved and exist on disk.
+    """
+    # Parse the code for image filenames saved via save_image
+    filenames = parse_saved_image_filenames(code_str)
+
+    # Capture stdout using a StringIO buffer
+    buf = io.StringIO()
+    with contextlib.redirect_stdout(buf):
+        # IMPORTANT: Here we exec the generated code.
+        exec(code_str, globals(), locals())
+
+    # Gather all printed output
+    output = buf.getvalue()
+
+    # Check which of the parsed filenames exist on disk (prepend TEMP_DIR if needed)
+    existing_images = []
+    for fn in filenames:
+        # If filename is not an absolute path, assume it is in TEMP_DIR
+        if not os.path.isabs(fn):
+            fn = os.path.join(TEMP_DIR, fn)
+        if os.path.exists(fn):
+            existing_images.append(fn)
+    return output, existing_images
+
+#############################################
+# CHAT FUNCTION: PROCESS USER PROMPT & IMAGE
+#############################################
+
+def chat(prompt, image, history):
+    """
+    When the user sends a prompt and optionally an image, do the following:
+      1. Save the image to a temp file.
+      2. Use VisionAgentCoderV2 to generate code for the task.
+      3. Execute the generated code, capturing its stdout logs and any saved image files.
+      4. Append the logs and image gallery info to the conversation history.
+    """
+    # Validate that an image was provided.
+    if image is None:
+        history.append(("System", "Please upload an image."))
+        return history, None
+
+    # Save the uploaded image for use in the generated code.
+    image_path = save_uploaded_image(image)
+
+    # Generate the code with VisionAgent using the user prompt and the image filename.
+    code_context = agent.generate_code(
+        [
+            AgentMessage(
+                role="user",
+                content=prompt,
+                media=[image_path]
+            )
+        ]
+    )
 
+    # Combine the generated code and its test snippet.
+    generated_code = code_context.code + "\n" + code_context.test
+
+    # Run the generated code and capture output and any saved images.
+    stdout_text, image_files = run_and_capture_with_images(generated_code)
 
+    # Format the response text (the captured logs).
+    response_text = f"**Execution Logs:**\n{stdout_text}\n"
+    if image_files:
+        response_text += "\n**Saved Images:** " + ", ".join(image_files)
+    else:
+        response_text += "\nNo images were saved by the generated code."
 
+    # Append the prompt and response to the chat history.
+    history.append((prompt, response_text))
 
+    # Optionally, you could clear the image input after use.
+    return history, image_files
 
+#############################################
+# GRADIO CHAT INTERFACE
+#############################################
 
 with gr.Blocks() as demo:
+    gr.Markdown("# VisionAgent Chat App")
+    gr.Markdown(
+        """
+        This chat app lets you enter a prompt (e.g., "Count the number of cacao oranges in the image")
+        along with an image. The app then uses VisionAgentCoderV2 to generate multi-step code, executes it,
+        and returns the detailed logs and any saved images.
+        """
+    )
 
     with gr.Row():
+        with gr.Column(scale=7):
+            chatbot = gr.Chatbot(label="Chat History")
+            prompt_input = gr.Textbox(label="Enter Prompt", placeholder="e.g., Count the number of cacao oranges in the image")
+            submit_btn = gr.Button("Send")
+        with gr.Column(scale=5):
+            image_input = gr.Image(label="Upload Image", type="numpy")
 
+            gallery = gr.Gallery(label="Generated Images").style(grid=[2], height="auto")
+
+    # Clear chat history button
+    clear_btn = gr.Button("Clear Chat")
+
+    # Chat function wrapper (it takes current chat history, prompt, image)
+    def user_chat_wrapper(prompt, image, history):
+        history = history or []
+        history, image_files = chat(prompt, image, history)
+        return history, image_files
+
+    submit_btn.click(fn=user_chat_wrapper, inputs=[prompt_input, image_input, chatbot], outputs=[chatbot, gallery])
+
+    clear_btn.click(lambda: ([], None), None, [chatbot, gallery])
+
+demo.launch()
 
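The replacement code hands the whole task to VisionAgentCoderV2 and then reports on what the generated code did: parse_saved_image_filenames() pulls filenames out of save_image(..., '<name>') calls, run_and_capture_with_images() exec()s the code while redirecting stdout, and only files that afterwards exist under temp_images/ reach the gallery. A hypothetical smoke test of those two helpers, with a hand-written toy_code string standing in for agent output (its local save_image shim and demo_output.jpg are made up for illustration, and the helpers from the updated app.py are assumed to be in scope), might look like:

# Hypothetical smoke test for the capture helpers above; everything in toy_code is made up.
toy_code = """
import os
os.makedirs("temp_images", exist_ok=True)

from PIL import Image

def save_image(img, path):
    # Stand-in for vision_agent.tools.save_image that writes where the
    # filename parser expects relative paths to live (temp_images/).
    img.save("temp_images/" + path)

print("step 1: creating a blank test image")
img = Image.new("RGB", (64, 64), "white")
save_image(img, 'demo_output.jpg')
print("step 2: done")
"""

logs, images = run_and_capture_with_images(toy_code)
print(logs)    # the two "step ..." lines captured from stdout
print(images)  # e.g. ['temp_images/demo_output.jpg'] once the file exists there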