from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import seaborn as sns
import itertools


def create_model():
    # Load a SegFormer-B0 model fine-tuned for semantic segmentation on ADE20K.
    return SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")


def create_feature_extractor():
    # Load the matching feature extractor (resizing + normalization) for the model.
    return SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")


def postprocess(masks, height, width):
    # Upsample the low-resolution logits back to the original image size.
    masks = F.interpolate(masks, size=(height, width), mode="bilinear", align_corners=False)

    # Pick the most likely class for every pixel.
    label_per_pixel = torch.argmax(masks.squeeze(), dim=0).detach().numpy()

    # Paint each predicted class with its own color from a cycling seaborn palette.
    color_mask = np.zeros(label_per_pixel.shape + (3,))
    palette = itertools.cycle(sns.color_palette())

    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255

    return color_mask


def segment(image: Image.Image, model, feature_extractor) -> np.ndarray:
    # Preprocess the image and run the model without tracking gradients.
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = outputs.logits

    # Turn the logits into a color mask and blend it over the original image.
    color_mask = postprocess(masks, image.height, image.width)
    pred_img = np.array(image.convert('RGB')) * 0.25 + color_mask * 0.75
    pred_img = pred_img.astype(np.uint8)

    return pred_img
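

# Minimal usage sketch: load the model and feature extractor once, segment a
# local image, and save the blended overlay. The file names "example.jpg" and
# "example_segmented.png" are placeholders for illustration only.
if __name__ == "__main__":
    model = create_model()
    feature_extractor = create_feature_extractor()

    image = Image.open("example.jpg")
    overlay = segment(image, model, feature_extractor)

    Image.fromarray(overlay).save("example_segmented.png")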