Spaces:
Sleeping
Sleeping
File size: 1,860 Bytes
66aead2 e6b5e51 66aead2 e6b5e51 66aead2 e6b5e51 66aead2 e6b5e51 66aead2 e6b5e51 66aead2 a0ff18d 66aead2 e6b5e51 66aead2 a0ff18d 66aead2 a0ff18d 66aead2 e6b5e51 66aead2 e6b5e51 66aead2 e6b5e51 66aead2 a0ff18d 66aead2 e6b5e51 66aead2 a0ff18d e6b5e51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from functools import lru_cache
from typing import Optional, Tuple
import numpy as np
from PIL import Image
from PIL.Image import Resampling
from huggingface_hub import hf_hub_download
from encode import rgb_encode
from image import ImageTyping, load_image
from onnxruntime_ import open_onnx_model
# Public API of this module.
__all__ = ['get_monochrome_score', 'is_monochrome']
# Names of the known monochrome-detection checkpoints.
_MONOCHROME_CKPTS = ['mobilenetv3_large_100_safe2', 'mobilenetv3_large_100', 'caformer_s36']

# The first listed checkpoint is used when the caller does not pick one.
_DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]
@lru_cache()
def _monochrome_validate_model(model):
    """
    Fetch and open the ONNX model for the given checkpoint name.

    The file ``<model>/model.onnx`` is requested from the
    ``deepghs/monochrome_detect`` repository via ``hf_hub_download`` and the
    resulting path is handed to ``open_onnx_model``. Results are memoized per
    ``model`` value by ``lru_cache``, so each checkpoint is opened only once
    while it stays in the cache.

    :param model: Checkpoint name, used as the directory of ``model.onnx``.
    :return: Whatever ``open_onnx_model`` returns for the downloaded file.
    """
    onnx_path = hf_hub_download(
        'deepghs/monochrome_detect',
        f'{model}/model.onnx',
    )
    return open_onnx_model(onnx_path)
def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
               normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Turn a PIL image into encoded array data for the model.

    The image is converted to RGB when its mode differs, resized to ``size``
    with bilinear resampling, and encoded by ``rgb_encode`` in ``CHW`` order.
    Unless ``normalize`` is ``None``, the result is then transformed as
    ``(data - mean) / std``.

    :param image: Source image.
    :param size: Target size passed to ``Image.resize``.
    :param normalize: ``(mean, std)`` pair, or ``None`` to skip normalization.
    :return: The encoded, optionally normalized, data.
    """
    rgb = image if image.mode == 'RGB' else image.convert('RGB')
    resized = rgb.resize(size, Resampling.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')
    if normalize is None:
        return encoded

    mean_, std_ = normalize
    # Reshaping to (-1, 1, 1) lets the statistics broadcast over the two
    # trailing axes of the encoded data.
    mean = np.asarray([mean_]).reshape((-1, 1, 1))
    std = np.asarray([std_]).reshape((-1, 1, 1))
    return (encoded - mean) / std
def get_monochrome_score(image: ImageTyping, model: str = _DEFAULT_MONOCHROME_CKPT):
    """
    Run the monochrome-detection model on an image and return its scores.

    The image is loaded through ``load_image`` with ``mode='RGB'``, encoded by
    ``_2d_encode``, cast to ``float32``, and fed to the ONNX model as a batch
    of one under the input name ``'input'``. The values of the ``'output'``
    tensor for that single sample are paired, in order, with the labels
    ``'monochrome'`` and ``'normal'``.

    :param image: Image to evaluate; anything ``load_image`` accepts.
    :param model: Checkpoint name passed to ``_monochrome_validate_model``.
    :return: Dict mapping ``'monochrome'`` and ``'normal'`` to Python scalars.
    """
    rgb_image = load_image(image, mode='RGB')
    encoded = _2d_encode(rgb_image).astype(np.float32)
    batch = np.stack([encoded])  # prepend a batch axis of length 1

    session = _monochrome_validate_model(model)
    output_data, = session.run(['output'], {'input': batch})

    scores = {}
    for name, v in zip(['monochrome', 'normal'], output_data[0]):
        scores[name] = v.item()
    return scores
def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
    """
    Decide whether an image is monochrome.

    :param image: Image to evaluate; forwarded to ``get_monochrome_score``.
    :param threshold: Minimum ``'monochrome'`` score for the image to count as
        monochrome.
    :param ckpt: Checkpoint name forwarded to ``get_monochrome_score``.
    :return: ``True`` when the ``'monochrome'`` score is at least ``threshold``.
    """
    # get_monochrome_score returns a dict of label -> score; comparing the
    # dict itself against a float raises TypeError, so pick the relevant entry.
    scores = get_monochrome_score(image, ckpt)
    return scores['monochrome'] >= threshold
|