|  | import json | 
					
						
						|  | import openai | 
					
						
						|  | import os | 
					
						
						|  | from datetime import datetime | 
					
						
						|  | import base64 | 
					
						
						|  | import logging | 
					
						
						|  | from pathlib import Path | 
					
						
						|  | import time | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  | from typing import Dict, List, Optional, Union, Any | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# When True, lowers the log level to DEBUG and enables per-request payload dumps.
DEBUG_MODE = False
# Directory where batch result JSON files are written.
OUTPUT_DIR = "results"
# Pinned OpenAI model snapshot used for every request.
MODEL_NAME = "gpt-4o-2024-05-13"
# Low temperature to keep the single-letter answers near-deterministic.
TEMPERATURE = 0.2
# Key of the subset to evaluate inside the chexbench JSON file.
SUBSET = "Visual Question Answering"


# Configure root logging once at import time; level follows DEBUG_MODE.
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_mime_type(file_path: str) -> str: | 
					
						
						|  | """ | 
					
						
						|  | Determine MIME type based on file extension. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | file_path (str): Path to the file | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | str: MIME type string for the file | 
					
						
						|  | """ | 
					
						
						|  | extension = os.path.splitext(file_path)[1].lower() | 
					
						
						|  | mime_types = { | 
					
						
						|  | ".png": "image/png", | 
					
						
						|  | ".jpg": "image/jpeg", | 
					
						
						|  | ".jpeg": "image/jpeg", | 
					
						
						|  | ".gif": "image/gif", | 
					
						
						|  | } | 
					
						
						|  | return mime_types.get(extension, "application/octet-stream") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def encode_image(image_path: str) -> str: | 
					
						
						|  | """ | 
					
						
						|  | Encode image to base64 with extensive error checking. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | image_path (str): Path to the image file | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | str: Base64 encoded image string | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | FileNotFoundError: If image file does not exist | 
					
						
						|  | ValueError: If image file is empty or too large | 
					
						
						|  | Exception: For other image processing errors | 
					
						
						|  | """ | 
					
						
						|  | logger.debug(f"Attempting to read image from: {image_path}") | 
					
						
						|  | if not os.path.exists(image_path): | 
					
						
						|  | raise FileNotFoundError(f"Image file not found: {image_path}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | file_size = os.path.getsize(image_path) | 
					
						
						|  | if file_size > 20 * 1024 * 1024: | 
					
						
						|  | raise ValueError("Image file size exceeds 20MB limit") | 
					
						
						|  | if file_size == 0: | 
					
						
						|  | raise ValueError("Image file is empty") | 
					
						
						|  | logger.debug(f"Image file size: {file_size / 1024:.2f} KB") | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with Image.open(image_path) as img: | 
					
						
						|  |  | 
					
						
						|  | width, height = img.size | 
					
						
						|  | format = img.format | 
					
						
						|  | mode = img.mode | 
					
						
						|  | logger.debug( | 
					
						
						|  | f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | if format not in ["PNG", "JPEG", "GIF"]: | 
					
						
						|  | raise ValueError(f"Unsupported image format: {format}") | 
					
						
						|  |  | 
					
						
						|  | with open(image_path, "rb") as image_file: | 
					
						
						|  |  | 
					
						
						|  | header = image_file.read(8) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | image_file.seek(0) | 
					
						
						|  | encoded = base64.b64encode(image_file.read()).decode("utf-8") | 
					
						
						|  | encoded_length = len(encoded) | 
					
						
						|  | logger.debug(f"Base64 encoded length: {encoded_length} characters") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if encoded_length == 0: | 
					
						
						|  | raise ValueError("Base64 encoding produced empty string") | 
					
						
						|  | if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"): | 
					
						
						|  | logger.warning("Base64 string doesn't start with expected JPEG or PNG header") | 
					
						
						|  |  | 
					
						
						|  | return encoded | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error reading/encoding image: {str(e)}") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def create_single_request( | 
					
						
						|  | image_path: str, question: str, options: Dict[str, str] | 
					
						
						|  | ) -> List[Dict[str, Any]]: | 
					
						
						|  | """ | 
					
						
						|  | Create a single API request with image and question. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | image_path (str): Path to the image file | 
					
						
						|  | question (str): Question text | 
					
						
						|  | options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1' | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | List[Dict[str, Any]]: List of message dictionaries for the API request | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | Exception: For errors in request creation | 
					
						
						|  | """ | 
					
						
						|  | if DEBUG_MODE: | 
					
						
						|  | logger.debug("Creating API request...") | 
					
						
						|  |  | 
					
						
						|  | prompt = f"""Given the following medical examination question: | 
					
						
						|  | Please answer this multiple choice question: | 
					
						
						|  |  | 
					
						
						|  | Question: {question} | 
					
						
						|  |  | 
					
						
						|  | Options: | 
					
						
						|  | A) {options['option_0']} | 
					
						
						|  | B) {options['option_1']} | 
					
						
						|  |  | 
					
						
						|  | Base your answer only on the provided image and select either A or B.""" | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | encoded_image = encode_image(image_path) | 
					
						
						|  | mime_type = get_mime_type(image_path) | 
					
						
						|  |  | 
					
						
						|  | if DEBUG_MODE: | 
					
						
						|  | logger.debug(f"Image encoded with MIME type: {mime_type}") | 
					
						
						|  |  | 
					
						
						|  | messages = [ | 
					
						
						|  | { | 
					
						
						|  | "role": "system", | 
					
						
						|  | "content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.", | 
					
						
						|  | }, | 
					
						
						|  | { | 
					
						
						|  | "role": "user", | 
					
						
						|  | "content": [ | 
					
						
						|  | {"type": "text", "text": prompt}, | 
					
						
						|  | { | 
					
						
						|  | "type": "image_url", | 
					
						
						|  | "image_url": {"url": f"data:{mime_type};base64,{encoded_image}"}, | 
					
						
						|  | }, | 
					
						
						|  | ], | 
					
						
						|  | }, | 
					
						
						|  | ] | 
					
						
						|  |  | 
					
						
						|  | if DEBUG_MODE: | 
					
						
						|  | log_messages = json.loads(json.dumps(messages)) | 
					
						
						|  | log_messages[1]["content"][1]["image_url"][ | 
					
						
						|  | "url" | 
					
						
						|  | ] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]" | 
					
						
						|  | logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}") | 
					
						
						|  |  | 
					
						
						|  | return messages | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Error creating request: {str(e)}") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def check_answer(model_answer: str, correct_answer: int) -> bool: | 
					
						
						|  | """ | 
					
						
						|  | Check if the model's answer matches the correct answer. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | model_answer (str): The model's answer (A or B) | 
					
						
						|  | correct_answer (int): The correct answer index (0 for A, 1 for B) | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | bool: True if answer is correct, False otherwise | 
					
						
						|  | """ | 
					
						
						|  | if not isinstance(model_answer, str): | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | model_letter = model_answer.strip().upper() | 
					
						
						|  | if model_letter.startswith("A"): | 
					
						
						|  | model_index = 0 | 
					
						
						|  | elif model_letter.startswith("B"): | 
					
						
						|  | model_index = 1 | 
					
						
						|  | else: | 
					
						
						|  | return False | 
					
						
						|  |  | 
					
						
						|  | return model_index == correct_answer | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str: | 
					
						
						|  | """ | 
					
						
						|  | Save results to a JSON file with timestamp. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | results (List[Dict[str, Any]]): List of result dictionaries | 
					
						
						|  | output_dir (str): Directory to save results | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | str: Path to the saved file | 
					
						
						|  | """ | 
					
						
						|  | Path(output_dir).mkdir(parents=True, exist_ok=True) | 
					
						
						|  | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | 
					
						
						|  | output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json") | 
					
						
						|  |  | 
					
						
						|  | with open(output_file, "w") as f: | 
					
						
						|  | json.dump(results, f, indent=2) | 
					
						
						|  |  | 
					
						
						|  | logger.info(f"Batch results saved to {output_file}") | 
					
						
						|  | return output_file | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]: | 
					
						
						|  | """ | 
					
						
						|  | Calculate accuracy from results, handling error cases. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | results (List[Dict[str, Any]]): List of result dictionaries | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total) | 
					
						
						|  | """ | 
					
						
						|  | if not results: | 
					
						
						|  | return 0.0, 0, 0 | 
					
						
						|  |  | 
					
						
						|  | total = len(results) | 
					
						
						|  | valid_results = [r for r in results if "output" in r] | 
					
						
						|  | correct = sum( | 
					
						
						|  | 1 for result in valid_results if result.get("output", {}).get("is_correct", False) | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | accuracy = (correct / total * 100) if total > 0 else 0 | 
					
						
						|  | return accuracy, correct, total | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float: | 
					
						
						|  | """ | 
					
						
						|  | Calculate accuracy for the current batch. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | results (List[Dict[str, Any]]): List of result dictionaries | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | float: Accuracy percentage for the batch | 
					
						
						|  | """ | 
					
						
						|  | valid_results = [r for r in results if "output" in r] | 
					
						
						|  | if not valid_results: | 
					
						
						|  | return 0.0 | 
					
						
						|  | return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def process_batch(
    data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
) -> List[Dict[str, Any]]:
    """
    Process a batch of examples and return results.

    Each example is sent as one chat-completion request; per-example failures
    are captured as error records (with an "error" key and no "output" key)
    rather than aborting the batch.

    Args:
        data (List[Dict[str, Any]]): List of data items to process
        client (openai.OpenAI): OpenAI client instance
        start_idx (int, optional): Starting index for batch. Defaults to 0
        batch_size (int, optional): Size of batch to process. Defaults to 50

    Returns:
        List[Dict[str, Any]]: List of processed results
    """
    batch_results = []
    # Clamp the batch end so the final (possibly short) batch is handled.
    end_idx = min(start_idx + batch_size, len(data))

    pbar = tqdm(
        range(start_idx, end_idx),
        desc=f"Processing batch {start_idx//batch_size + 1}",
        unit="example",
    )

    for index in pbar:
        vqa_item = data[index]
        # NOTE(review): assumes each item carries 'option_0', 'option_1',
        # 'image_path', 'question', and an integer 'answer' (0 for A, 1 for B)
        # — schema comes from the chexbench JSON; confirm against the data file.
        options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}

        try:
            messages = create_single_request(
                image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
            )

            # max_tokens=50 is ample since the model is instructed to reply
            # with a single letter.
            response = client.chat.completions.create(
                model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
            )

            model_answer = response.choices[0].message.content.strip()
            is_correct = check_answer(model_answer, vqa_item["answer"])

            # Success record: "output" key present marks a scored example.
            result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
                "output": {
                    "model_answer": model_answer,
                    "correct_answer": "A" if vqa_item["answer"] == 0 else "B",
                    "is_correct": is_correct,
                    "usage": {
                        "prompt_tokens": response.usage.prompt_tokens,
                        "completion_tokens": response.usage.completion_tokens,
                        "total_tokens": response.usage.total_tokens,
                    },
                },
            }
            batch_results.append(result)

            # Refresh the progress-bar label with the running batch accuracy.
            current_accuracy = calculate_batch_accuracy(batch_results)
            pbar.set_description(
                f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
                f"({len(batch_results)}/{index-start_idx+1} examples)"
            )

        except Exception as e:
            # Failure record: no "output" key, so accuracy helpers treat it
            # as unscored/wrong as appropriate.
            error_result = {
                "timestamp": datetime.now().isoformat(),
                "example_index": index,
                "error": str(e),
                "input": {
                    "question": vqa_item["question"],
                    "options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
                    "image_path": vqa_item["image_path"],
                },
            }
            batch_results.append(error_result)
            if DEBUG_MODE:
                pbar.write(f"Error processing example {index}: {str(e)}")

        # Crude rate limiting between requests (runs after success or failure).
        time.sleep(1)

    return batch_results
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main() -> None: | 
					
						
						|  | """ | 
					
						
						|  | Main function to process the entire dataset. | 
					
						
						|  |  | 
					
						
						|  | Raises: | 
					
						
						|  | ValueError: If OPENAI_API_KEY is not set | 
					
						
						|  | Exception: For other processing errors | 
					
						
						|  | """ | 
					
						
						|  | logger.info("Starting full dataset processing...") | 
					
						
						|  | json_path = "../data/chexbench_updated.json" | 
					
						
						|  |  | 
					
						
						|  | try: | 
					
						
						|  | api_key = os.getenv("OPENAI_API_KEY") | 
					
						
						|  | if not api_key: | 
					
						
						|  | raise ValueError("OPENAI_API_KEY environment variable is not set.") | 
					
						
						|  | client = openai.OpenAI(api_key=api_key) | 
					
						
						|  |  | 
					
						
						|  | with open(json_path, "r") as f: | 
					
						
						|  | data = json.load(f) | 
					
						
						|  |  | 
					
						
						|  | subset_data = data[SUBSET] | 
					
						
						|  | total_examples = len(subset_data) | 
					
						
						|  | logger.info(f"Found {total_examples} examples in {SUBSET} subset") | 
					
						
						|  |  | 
					
						
						|  | all_results = [] | 
					
						
						|  | batch_size = 50 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | for start_idx in range(0, total_examples, batch_size): | 
					
						
						|  | batch_results = process_batch(subset_data, client, start_idx, batch_size) | 
					
						
						|  | all_results.extend(batch_results) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | output_file = save_results_to_json(all_results, OUTPUT_DIR) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | overall_accuracy, correct, total = calculate_accuracy(all_results) | 
					
						
						|  | logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed") | 
					
						
						|  | logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)") | 
					
						
						|  |  | 
					
						
						|  | logger.info("Processing completed!") | 
					
						
						|  | logger.info(f"Final results saved to: {output_file}") | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | logger.error(f"Fatal error: {str(e)}") | 
					
						
						|  | raise | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Script entry point: run the full evaluation when executed directly.
if __name__ == "__main__":
    main()