File size: 1,860 Bytes
66aead2
e6b5e51
66aead2
 
e6b5e51
 
66aead2
 
e6b5e51
66aead2
 
 
 
 
e6b5e51
66aead2
 
e6b5e51
66aead2
a0ff18d
 
 
66aead2
e6b5e51
66aead2
 
 
a0ff18d
66aead2
a0ff18d
 
66aead2
 
 
e6b5e51
 
 
 
 
 
66aead2
e6b5e51
 
 
 
 
66aead2
e6b5e51
66aead2
 
a0ff18d
66aead2
e6b5e51
66aead2
a0ff18d
 
e6b5e51
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from functools import lru_cache
from typing import Optional, Tuple

import numpy as np
from PIL import Image
from PIL.Image import Resampling
from huggingface_hub import hf_hub_download

from encode import rgb_encode
from image import ImageTyping, load_image
from onnxruntime_ import open_onnx_model

__all__ = ['get_monochrome_score', 'is_monochrome']

# Checkpoint names; each is looked up as ``<name>/model.onnx`` in the
# ``deepghs/monochrome_detect`` repository by ``_monochrome_validate_model``.
# The first entry is the default used by the public functions.
# (A previous default was the single file 'monochrome-resnet18-safe2-450.onnx'.)
_MONOCHROME_CKPTS = [
    'mobilenetv3_large_100_safe2',
    'mobilenetv3_large_100',
    'caformer_s36',
]
_DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]


@lru_cache()
def _monochrome_validate_model(model):
    """
    Fetch and open the ONNX model file for a monochrome-detection checkpoint.

    The file ``<model>/model.onnx`` is fetched with ``hf_hub_download`` from the
    ``deepghs/monochrome_detect`` repository and opened with ``open_onnx_model``.
    Calls are memoized with ``lru_cache``, so repeated requests for the same
    checkpoint name reuse the already-opened model.

    :param model: Checkpoint directory name inside the repository.
    :return: The object returned by ``open_onnx_model`` for that file.
    """
    repo_id = 'deepghs/monochrome_detect'
    filename = f'{model}/model.onnx'
    local_path = hf_hub_download(repo_id, filename)
    return open_onnx_model(local_path)


def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Turn a PIL image into a channel-first array for model input.

    :param image: Source image; converted to RGB first when it is in another mode.
    :param size: Target size passed to ``Image.resize`` (bilinear resampling).
    :param normalize: Optional ``(mean, std)`` pair. When given, the encoded
        array has ``mean`` subtracted and is divided by ``std``; ``None`` skips
        this step.
    :return: The ``rgb_encode`` output in ``'CHW'`` order, optionally normalized.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    resized = rgb.resize(size, Resampling.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')

    if normalize is None:
        return encoded

    mean_value, std_value = normalize
    # Reshape to (-1, 1, 1) so the statistics broadcast over the last two
    # (spatial) axes of the CHW array.
    mean = np.asarray([mean_value]).reshape((-1, 1, 1))
    std = np.asarray([std_value]).reshape((-1, 1, 1))
    return (encoded - mean) / std


def get_monochrome_score(image: ImageTyping, model: str = _DEFAULT_MONOCHROME_CKPT):
    """
    Run the monochrome classifier on an image and report its per-class outputs.

    :param image: Anything accepted by ``load_image``; it is loaded in RGB mode.
    :param model: Checkpoint name passed to ``_monochrome_validate_model``.
    :return: Dict mapping ``'monochrome'`` and ``'normal'`` to the first and
        second values of the model's ``'output'`` for this image, converted to
        Python scalars with ``.item()``.
    """
    rgb_image = load_image(image, mode='RGB')
    encoded = _2d_encode(rgb_image).astype(np.float32)
    # Prepend a batch axis of size 1: the model is fed a single-image batch.
    batch = np.expand_dims(encoded, axis=0)
    session = _monochrome_validate_model(model)
    (scores,) = session.run(['output'], {'input': batch})
    labels = ('monochrome', 'normal')
    return {label: value.item() for label, value in zip(labels, scores[0])}


def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
    """
    Decide whether an image is monochrome.

    :param image: Anything accepted by ``get_monochrome_score``.
    :param threshold: Minimum ``'monochrome'`` score for the image to be
        classified as monochrome.
    :param ckpt: Checkpoint name, forwarded to ``get_monochrome_score`` as its
        ``model`` argument.
    :return: ``True`` when the ``'monochrome'`` score is at least ``threshold``.
    """
    # get_monochrome_score returns a dict of per-class scores. Comparing that
    # dict directly with a float raises TypeError, so select the monochrome entry.
    scores = get_monochrome_score(image, ckpt)
    return scores['monochrome'] >= threshold