from typing import List, Union

import numpy as np
import torch
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from PIL import Image, ImageDraw
from transformers import AutoProcessor, Florence2ForConditionalGeneration


class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
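    """Annotate images with a Florence-2 vision-language model.

    Depending on ``annotation_output_type``, the block exposes the raw
    annotation predictions, a black-and-white inpainting mask, a mask
    overlaid on the original image, or bounding boxes drawn on it.
    """
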
    @property
    def expected_components(self):
        return [
            ComponentSpec(
                name="image_annotator",
                type_hint=Florence2ForConditionalGeneration,
                repo="florence-community/Florence-2-base-ft",
            ),
            ComponentSpec(
                name="image_annotator_processor",
                type_hint=AutoProcessor,
                repo="florence-community/Florence-2-base-ft",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=True,
                description="Image(s) to annotate",
            ),
            InputParam(
                "annotation_task",
                type_hint=Union[str, List[str]],
                required=True,
                default="<REFERRING_EXPRESSION_SEGMENTATION>",
                description="""Annotation task to perform on the image.
                Supported tasks:

                <OD>
                <REFERRING_EXPRESSION_SEGMENTATION>
                <CAPTION>
                <DETAILED_CAPTION>
                <MORE_DETAILED_CAPTION>
                <DENSE_REGION_CAPTION>
                <CAPTION_TO_PHRASE_GROUNDING>
                <OPEN_VOCABULARY_DETECTION>
                """,
            ),
            InputParam(
                "annotation_prompt",
                type_hint=Union[str, List[str]],
                required=True,
                description="""Annotation prompt that provides more context for the task.
                Can be used to detect or segment specific elements in the image.
                """,
            ),
            InputParam(
                "annotation_output_type",
                type_hint=str,
                required=True,
                default="mask_image",
                description="""Output type for the annotation predictions. Available options are

                annotation:
                    - raw annotation predictions from the model, based on the task type
                mask_image:
                    - black-and-white mask image for the given image, based on the task type
                mask_overlay:
                    - mask overlaid on the original image
                bounding_box:
                    - bounding boxes drawn on the original image
                """,
            ),
            InputParam(
                "annotation_overlay",
                type_hint=bool,
                required=True,
                default=False,
                description="Whether to draw the annotation mask over the original image.",
            ),
            InputParam(
                "fill",
                type_hint=str,
                default="white",
                description="Fill color used when drawing mask polygons and boxes.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "mask_image",
                type_hint=Image.Image,
                description="Inpainting mask for the input image(s)",
            ),
            OutputParam(
                "annotations",
                type_hint=dict,
                description="Annotation predictions for the input image(s)",
            ),
            OutputParam(
                "image",
                type_hint=Image.Image,
                description="Annotated input image(s)",
            ),
        ]

    def get_annotations(self, components, images, prompts, task):
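        """Run Florence-2 over each image and return post-processed predictions.

        The task token (e.g. ``<OD>``) is prepended to each prompt, generation
        runs with beam search, and the decoded text is converted into
        structured annotations sized to each image.
        """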
        task_prompts = [task + prompt for prompt in prompts]

        inputs = components.image_annotator_processor(
            text=task_prompts, images=images, return_tensors="pt"
        ).to(components.image_annotator.device, components.image_annotator.dtype)

        generated_ids = components.image_annotator.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        annotations = components.image_annotator_processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )
        outputs = []
        for image, annotation in zip(images, annotations):
            outputs.append(
                components.image_annotator_processor.post_process_generation(
                    annotation, task=task, image_size=(image.width, image.height)
                )
            )
        return outputs

    def prepare_mask(self, images, annotations, overlay=False, fill="white"):
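        """Draw the predicted polygons or boxes as a mask.

        With ``overlay=False`` the mask is drawn on a black canvas; with
        ``overlay=True`` it is drawn directly on a copy of the input image.
        """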
        masks = []
        for image, annotation in zip(images, annotations):
            mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask_image)

            for _, _annotation in annotation.items():
                if "polygons" in _annotation:
                    for polygons in _annotation["polygons"]:
                        for polygon in polygons:
                            polygon = np.array(polygon).reshape(-1, 2)
                            if len(polygon) < 3:
                                continue
                            draw.polygon(polygon.reshape(-1).tolist(), fill=fill)

                elif "bboxes" in _annotation:
                    for bbox in _annotation["bboxes"]:
                        draw.rectangle(bbox, fill=fill)

            masks.append(mask_image)

        return masks

    def prepare_bounding_boxes(self, images, annotations):
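        """Draw predicted bounding boxes and labels on copies of the input images."""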
        outputs = []
        for image, annotation in zip(images, annotations):
            image_copy = image.copy()
            draw = ImageDraw.Draw(image_copy)
            for _, _annotation in annotation.items():
                for bbox, label in zip(_annotation["bboxes"], _annotation["labels"]):
                    draw.rectangle(bbox, outline="red", width=3)
                    draw.text((bbox[0], bbox[1] - 20), label, fill="red")

            outputs.append(image_copy)

        return outputs

    def prepare_inputs(self, images, prompts):
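        """Normalize images and prompts to equal-length lists."""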
        prompts = prompts or ""

        if isinstance(images, Image.Image):
            images = [images]
        if isinstance(prompts, str):
            prompts = [prompts] * len(images)

        if len(images) != len(prompts):
            raise ValueError("Number of images and annotation prompts must match.")

        return images, prompts

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
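        """Annotate the input image(s) and write the results back to the pipeline state."""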
        block_state = self.get_block_state(state)
        images, annotation_task_prompt = self.prepare_inputs(
            block_state.image, block_state.annotation_prompt
        )
        task = block_state.annotation_task
        fill = block_state.fill

        annotations = self.get_annotations(
            components, images, annotation_task_prompt, task
        )
        block_state.annotations = annotations
        if block_state.annotation_output_type == "mask_image":
            block_state.mask_image = self.prepare_mask(images, annotations)
        else:
            block_state.mask_image = None

        if block_state.annotation_output_type == "mask_overlay":
            block_state.image = self.prepare_mask(
                images, annotations, overlay=True, fill=fill
            )
        elif block_state.annotation_output_type == "bounding_box":
            block_state.image = self.prepare_bounding_boxes(images, annotations)

        self.set_block_state(state, block_state)

        return components, state
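

# Minimal usage sketch (an assumption, not part of the block itself): it presumes
# the diffusers modular-pipeline API with `init_pipeline()` / `load_components()`,
# which may differ across versions; the input file name is hypothetical.
#
#   blocks = Florence2ImageAnnotatorBlock()
#   pipe = blocks.init_pipeline()
#   pipe.load_components(torch_dtype=torch.float16)
#   pipe.to("cuda")
#   mask = pipe(
#       image=Image.open("example.png"),
#       annotation_task="<REFERRING_EXPRESSION_SEGMENTATION>",
#       annotation_prompt="the dog",
#       annotation_output_type="mask_image",
#       output="mask_image",
#   )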