|  | from typing import Any, Dict, Optional, Tuple, Type | 
					
						
						|  | from pydantic import BaseModel, Field | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from langchain_core.callbacks import ( | 
					
						
						|  | AsyncCallbackManagerForToolRun, | 
					
						
						|  | CallbackManagerForToolRun, | 
					
						
						|  | ) | 
					
						
						|  | from langchain_core.tools import BaseTool | 
					
						
						|  |  | 
					
						
						|  | from PIL import Image | 
					
						
						|  |  | 
					
						
						|  | from transformers import ( | 
					
						
						|  | BertTokenizer, | 
					
						
						|  | ViTImageProcessor, | 
					
						
						|  | VisionEncoderDecoderModel, | 
					
						
						|  | GenerationConfig, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ChestXRayInput(BaseModel): | 
					
						
						|  | """Input for chest X-ray analysis tools. Only supports JPG or PNG images.""" | 
					
						
						|  |  | 
					
						
						|  | image_path: str = Field( | 
					
						
						|  | ..., description="Path to the radiology image file, only supports JPG or PNG images" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class ChestXRayReportGeneratorTool(BaseTool): | 
					
						
						|  | """Tool that generates comprehensive chest X-ray reports with both findings and impressions. | 
					
						
						|  |  | 
					
						
						|  | This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert | 
					
						
						|  | and MIMIC-CXR datasets to generate structured radiology reports. It automatically | 
					
						
						|  | generates both detailed findings and impression summaries for each chest X-ray, | 
					
						
						|  | following standard radiological reporting format. | 
					
						
						|  |  | 
					
						
						|  | The tool uses: | 
					
						
						|  | - Findings model: Generates detailed observations of all visible structures | 
					
						
						|  | - Impression model: Provides concise clinical interpretation and key diagnoses | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | name: str = "chest_xray_report_generator" | 
					
						
						|  | description: str = ( | 
					
						
						|  | "A tool that analyzes chest X-ray images and generates comprehensive radiology reports " | 
					
						
						|  | "containing both detailed findings and impression summaries. Input should be the path " | 
					
						
						|  | "to a chest X-ray image file. Output is a structured report with both detailed " | 
					
						
						|  | "observations and key clinical conclusions." | 
					
						
						|  | ) | 
					
						
						|  | device: Optional[str] = "cuda" | 
					
						
						|  | args_schema: Type[BaseModel] = ChestXRayInput | 
					
						
						|  | findings_model: VisionEncoderDecoderModel = None | 
					
						
						|  | impression_model: VisionEncoderDecoderModel = None | 
					
						
						|  | findings_tokenizer: BertTokenizer = None | 
					
						
						|  | impression_tokenizer: BertTokenizer = None | 
					
						
						|  | findings_processor: ViTImageProcessor = None | 
					
						
						|  | impression_processor: ViTImageProcessor = None | 
					
						
						|  | generation_args: Dict[str, Any] = None | 
					
						
						|  |  | 
					
						
						|  | def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"): | 
					
						
						|  | """Initialize the ChestXRayReportGeneratorTool with both findings and impression models.""" | 
					
						
						|  | super().__init__() | 
					
						
						|  | self.device = torch.device(device) if device else "cuda" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.findings_model = VisionEncoderDecoderModel.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir | 
					
						
						|  | ).eval() | 
					
						
						|  | self.findings_tokenizer = BertTokenizer.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir | 
					
						
						|  | ) | 
					
						
						|  | self.findings_processor = ViTImageProcessor.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.impression_model = VisionEncoderDecoderModel.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir | 
					
						
						|  | ).eval() | 
					
						
						|  | self.impression_tokenizer = BertTokenizer.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir | 
					
						
						|  | ) | 
					
						
						|  | self.impression_processor = ViTImageProcessor.from_pretrained( | 
					
						
						|  | "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.findings_model = self.findings_model.to(self.device) | 
					
						
						|  | self.impression_model = self.impression_model.to(self.device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.generation_args = { | 
					
						
						|  | "num_return_sequences": 1, | 
					
						
						|  | "max_length": 128, | 
					
						
						|  | "use_cache": True, | 
					
						
						|  | "beam_width": 2, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | def _process_image( | 
					
						
						|  | self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel | 
					
						
						|  | ) -> torch.Tensor: | 
					
						
						|  | """Process the input image for a specific model. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | image_path (str): Path to the input image. | 
					
						
						|  | processor: Image processor for the specific model. | 
					
						
						|  | model: The model to process the image for. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | torch.Tensor: Processed image tensor ready for model input. | 
					
						
						|  | """ | 
					
						
						|  | image = Image.open(image_path).convert("RGB") | 
					
						
						|  | pixel_values = processor(image, return_tensors="pt").pixel_values | 
					
						
						|  |  | 
					
						
						|  | expected_size = model.config.encoder.image_size | 
					
						
						|  | actual_size = pixel_values.shape[-1] | 
					
						
						|  |  | 
					
						
						|  | if expected_size != actual_size: | 
					
						
						|  | pixel_values = torch.nn.functional.interpolate( | 
					
						
						|  | pixel_values, | 
					
						
						|  | size=(expected_size, expected_size), | 
					
						
						|  | mode="bilinear", | 
					
						
						|  | align_corners=False, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | pixel_values = pixel_values.to(self.device) | 
					
						
						|  |  | 
					
						
						|  | return pixel_values | 
					
						
						|  |  | 
					
						
						|  | def _generate_report_section( | 
					
						
						|  | self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer | 
					
						
						|  | ) -> str: | 
					
						
						|  | """Generate a report section using the specified model. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | pixel_values: Processed image tensor. | 
					
						
						|  | model: The model to use for generation. | 
					
						
						|  | tokenizer: The tokenizer for the model. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | str: Generated text for the report section. | 
					
						
						|  | """ | 
					
						
						|  | generation_config = GenerationConfig( | 
					
						
						|  | **{ | 
					
						
						|  | **self.generation_args, | 
					
						
						|  | "bos_token_id": model.config.bos_token_id, | 
					
						
						|  | "eos_token_id": model.config.eos_token_id, | 
					
						
						|  | "pad_token_id": model.config.pad_token_id, | 
					
						
						|  | "decoder_start_token_id": tokenizer.cls_token_id, | 
					
						
						|  | } | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | generated_ids = model.generate(pixel_values, generation_config=generation_config) | 
					
						
						|  |  | 
					
						
						|  | return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | 
					
						
						|  |  | 
					
						
						|  | def _run( | 
					
						
						|  | self, | 
					
						
						|  | image_path: str, | 
					
						
						|  | run_manager: Optional[CallbackManagerForToolRun] = None, | 
					
						
						|  | ) -> Tuple[str, Dict]: | 
					
						
						|  | """Generate a comprehensive chest X-ray report containing both findings and impression. | 
					
						
						|  |  | 
					
						
						|  | Args: | 
					
						
						|  | image_path (str): The path to the chest X-ray image file. | 
					
						
						|  | run_manager (Optional[CallbackManagerForToolRun]): The callback manager. | 
					
						
						|  |  | 
					
						
						|  | Returns: | 
					
						
						|  | Tuple[str, Dict]: A tuple containing the complete report and metadata. | 
					
						
						|  | """ | 
					
						
						|  | try: | 
					
						
						|  |  | 
					
						
						|  | findings_pixels = self._process_image( | 
					
						
						|  | image_path, self.findings_processor, self.findings_model | 
					
						
						|  | ) | 
					
						
						|  | impression_pixels = self._process_image( | 
					
						
						|  | image_path, self.impression_processor, self.impression_model | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | with torch.inference_mode(): | 
					
						
						|  | findings_text = self._generate_report_section( | 
					
						
						|  | findings_pixels, self.findings_model, self.findings_tokenizer | 
					
						
						|  | ) | 
					
						
						|  | impression_text = self._generate_report_section( | 
					
						
						|  | impression_pixels, self.impression_model, self.impression_tokenizer | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | report = ( | 
					
						
						|  | "CHEST X-RAY REPORT\n\n" | 
					
						
						|  | f"FINDINGS:\n{findings_text}\n\n" | 
					
						
						|  | f"IMPRESSION:\n{impression_text}" | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | metadata = { | 
					
						
						|  | "image_path": image_path, | 
					
						
						|  | "analysis_status": "completed", | 
					
						
						|  | "sections_generated": ["findings", "impression"], | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | return report, metadata | 
					
						
						|  |  | 
					
						
						|  | except Exception as e: | 
					
						
						|  | return f"Error generating report: {str(e)}", { | 
					
						
						|  | "image_path": image_path, | 
					
						
						|  | "analysis_status": "failed", | 
					
						
						|  | "error": str(e), | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | async def _arun( | 
					
						
						|  | self, | 
					
						
						|  | image_path: str, | 
					
						
						|  | run_manager: Optional[AsyncCallbackManagerForToolRun] = None, | 
					
						
						|  | ) -> Tuple[str, Dict]: | 
					
						
						|  | """Asynchronously generate a comprehensive chest X-ray report.""" | 
					
						
						|  | return self._run(image_path) | 
					
						
						|  |  |