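"""Mask2Former segmentation utilities.

Loads a pretrained Mask2Former checkpoint from the Hugging Face Hub, runs
semantic, instance, or panoptic segmentation on an input image, and renders
the predicted masks as a blended visualization.
"""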
import torch
import random
import numpy as np
from PIL import Image
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

def load_model_and_processor(model_ckpt: str):
    """Load a Mask2Former checkpoint and its image processor, on GPU when available."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor

def load_default_ckpt(segmentation_task: str):
    if segmentation_task == "semantic":
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-ade-semantic"
    elif segmentation_task == "instance":
        default_pretrained_ckpt = "facebook/mask2former-swin-small-coco-instance"
    else:
        default_pretrained_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_pretrained_ckpt

def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    # detectron2 expects `category_id` and `isthing` keys, while transformers
    # returns `label_id`; rename in place before visualizing.
    for res in seg_info:
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)

    # detectron2's Visualizer expects an RGB array, which PIL already provides.
    visualizer = Visualizer(np.array(image), metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    return output_img

def draw_semantic_segmentation(segmentation_map, image, palette):
    # Map each predicted class id to its palette color (height, width, 3).
    # post_process_semantic_segmentation returns 0-indexed labels that line up
    # with the ADE palette directly, so no offset is needed.
    color_segmentation_map = np.zeros((segmentation_map.shape[0], segmentation_map.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_segmentation_map[segmentation_map == label, :] = color

    # Blend the RGB color map with the original RGB image.
    img = np.array(image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img

def visualize_instance_seg_mask(mask, input_image):
    # Assign a random RGB color to every instance id present in the mask.
    color_segmentation_map = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    labels = np.unique(mask)
    label2color = {
        int(label): (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for label in labels
    }

    for label, color in label2color.items():
        color_segmentation_map[mask == label, :] = color

    # Blend the RGB color map with the original RGB image.
    img = np.array(input_image) * 0.5 + color_segmentation_map * 0.5
    img = img.astype(np.uint8)
    return img

def predict_masks(input_img_path: str, segmentation_task: str):
    # Load the model and image processor for the requested task.
    default_pretrained_ckpt = load_default_ckpt(segmentation_task)
    model, image_processor = load_model_and_processor(default_pretrained_ckpt)

    # Preprocess the input image and move the tensors to the model's device,
    # so inference also works when the model was placed on GPU.
    image = Image.open(input_img_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt").to(model.device)

    # Run the forward pass without tracking gradients.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process the raw outputs into a task-specific segmentation map.
    # Note: `image.size` is (width, height); `target_sizes` expects (height, width).
    if segmentation_task == "semantic":
        result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result.cpu().numpy()
        palette = ade_palette()
        output_result = draw_semantic_segmentation(predicted_segmentation_map, image, palette)

    elif segmentation_task == "instance":
        result = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_instance_map = result["segmentation"].cpu().numpy()
        output_result = visualize_instance_seg_mask(predicted_instance_map, image)

    else:
        result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
        predicted_segmentation_map = result["segmentation"]
        seg_info = result["segments_info"]
        output_result = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)

    return output_result
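
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original script: "sample.jpg" is a
    # placeholder path, substitute any local image. Semantic and instance tasks
    # return numpy arrays, panoptic a PIL image; np.asarray handles both.
    output = predict_masks("sample.jpg", "panoptic")
    Image.fromarray(np.asarray(output)).save("panoptic_output.png")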