from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from transformers.modeling_outputs import SemanticSegmenterOutput
from transformers.feature_extraction_utils import BatchFeature
from PIL import Image
import torch
import torch.nn.functional as F
import numpy as np
import seaborn as sns
import itertools
def create_model():
    # Load the SegFormer-B0 checkpoint fine-tuned on ADE20K at 512x512 resolution.
    return SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")


def create_feature_extractor():
    # Load the matching feature extractor (handles resizing and normalization).
    return SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
def postprocess(masks, height, width):
    # Upsample the low-resolution logits (1, num_labels, H/4, W/4) back to the
    # original image size, then take the most likely class per pixel.
    masks = F.interpolate(masks, (height, width), mode="bilinear", align_corners=False)
    label_per_pixel = torch.argmax(masks.squeeze(), dim=0).detach().numpy()
    # Paint each predicted class with a distinct seaborn palette color.
    color_mask = np.zeros(label_per_pixel.shape + (3,))
    palette = itertools.cycle(sns.color_palette())
    for lbl in np.unique(label_per_pixel):
        color_mask[label_per_pixel == lbl, :] = np.asarray(next(palette)) * 255
    return color_mask
def segment(image: Image.Image, model, feature_extractor) -> np.ndarray:
    # Preprocess the image, run the model, and blend the color-coded
    # prediction over the original image (25% image, 75% mask).
    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    masks = outputs.logits
    color_mask = postprocess(masks, image.height, image.width)
    pred_img = np.array(image.convert('RGB')) * 0.25 + color_mask * 0.75
    pred_img = pred_img.astype(np.uint8)
    return pred_img
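

# A minimal usage sketch wiring the helpers together on a local file.
# "example.jpg" and "overlay.png" are placeholder paths chosen here for illustration.
if __name__ == "__main__":
    model = create_model()
    feature_extractor = create_feature_extractor()
    image = Image.open("example.jpg")  # placeholder input image path
    overlay = segment(image, model, feature_extractor)
    # segment() returns an HxWx3 uint8 array, so it converts cleanly to a PIL image.
    Image.fromarray(overlay).save("overlay.png")  # placeholder output path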