from typing import List, Union

import numpy as np
import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Florence2ForConditionalGeneration


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
    @property
    def expected_components(self):
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                required=True,
                default="",
                description=(
                    "Annotation task to perform on the image. Supported tasks are "
                    "Florence-2 task tokens, e.g. `<OD>` for object detection or "
                    "`<REFERRING_EXPRESSION_SEGMENTATION>` for segmentation."
                ),
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                description=(
                    "Annotation prompt that provides more context for the task. "
                    "Can be used to detect or segment specific elements in the image."
                ),
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                required=True,
                default="mask_image",
                description=(
                    "Output type derived from the annotation predictions. Available options are:\n"
                    "annotation: raw annotation predictions from the model, based on the task type.\n"
                    "mask_image: black-and-white mask image for the given image, based on the task type.\n"
                    "mask_overlay: white mask overlaid on the original image.\n"
                    "bounding_box: bounding boxes drawn on the original image."
                ),
            ),
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="Whether to draw the annotation mask over a copy of the input image instead of a blank canvas.",
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="Fill color used when drawing mask regions, e.g. `white` or `black`.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mask_image",
                type_hint=Image.Image,
                description="Inpainting mask for the input image(s)",
            ),
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotation predictions for the input image(s)",
            ),
            OutputParam(
                "image",
                type_hint=Image.Image,
                description="Annotated input image(s)",
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
        # Florence-2 expects the task token to be prepended to the text prompt.
        task_prompts = [task + prompt for prompt in prompts]
        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )

        outputs = []
        for image, annotation in zip(images, annotations):
            # Parse the raw generation into task-specific structures
            # (bounding boxes, polygons, labels) scaled to the image size.
            outputs.append(
                components.image_annotator_processor.post_process_generation(
                    annotation, task=task, image_size=(image.width, image.height)
                )
            )

        return outputs

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
        masks = []
        for image, annotation in zip(images, annotations):
            # Draw on a copy of the image for overlays, otherwise on a blank
            # single-channel canvas.
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            for _, _annotation in annotation.items():
                if "polygons" in _annotation:
                    for polygon in _annotation["polygons"]:
                        polygon = np.array(polygon).reshape(-1, 2)
                        if len(polygon) < 3:
                            # A valid polygon needs at least three vertices.
                            continue
                        polygon = polygon.reshape(-1).tolist()
                        draw.polygon(polygon, fill=fill)
                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        draw.rectangle(bbox, fill=fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)

            for _, _annotation in annotation.items():
                for bbox, label in zip(_annotation["bboxes"], _annotation["labels"]):
                    draw.rectangle(bbox, outline="red", width=3)
                    # Place the label just above the top-left corner of the box.
                    draw.text((bbox[0], bbox[1] - 20), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts]
        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        task = block_state.annotation_task
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )
        block_state.annotations = annotations

        if block_state.annotation_output_type == "mask_image":
            block_state.mask_image = self.prepare_mask(images, annotations)
        else:
            block_state.mask_image = None

        if block_state.annotation_output_type == "mask_overlay":
            block_state.image = self.prepare_mask(
                images, annotations, overlay=True, fill=fill
            )
        elif block_state.annotation_output_type == "bounding_box":
            block_state.image = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
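

# --- Usage sketch ----------------------------------------------------------
# A minimal way to exercise the block's helper methods directly, without the
# full modular-pipeline wiring (init_pipeline / PipelineState). This is a
# sketch, not the canonical way to run a ModularPipelineBlocks instance: the
# SimpleNamespace merely stands in for the pipeline's component registry, and
# the gray test image, prompt, and output path are illustrative placeholders.
# It assumes the `florence-community/Florence-2-base-ft` checkpoint from the
# ComponentSpecs above is reachable and loads natively in transformers.
if __name__ == "__main__":
    from types import SimpleNamespace

    repo = "florence-community/Florence-2-base-ft"
    components = SimpleNamespace(
        image_annotator=Florence2ForConditionalGeneration.from_pretrained(repo),
        image_annotator_processor=AutoProcessor.from_pretrained(repo),
    )

    block = Florence2ImageAnnotatorBlock()
    image = Image.new("RGB", (512, 512), "gray")  # stand-in for a real photo
    images, prompts = block.prepare_inputs(image, "a cat")

    # Segment whatever matches the prompt, then rasterize it into an
    # inpainting-style mask (runs on CPU by default; move the model for GPU).
    annotations = block.get_annotations(
        components, images, prompts, task="<REFERRING_EXPRESSION_SEGMENTATION>"
    )
    mask = block.prepare_mask(images, annotations)[0]
    mask.save("mask.png")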