narugo1992 committed
Commit e6b5e51 · 1 Parent(s): 66aead2

dev(narugo): save it

Files changed (2):
  1. encode.py +59 -0
  2. monochrome.py +26 -45
encode.py ADDED
@@ -0,0 +1,59 @@
+ import numpy as np
+
+ from image import load_image, ImageTyping
+
+ __all__ = [
+     'rgb_encode',
+ ]
+
+ _DEFAULT_ORDER = 'HWC'
+
+
+ def _get_hwc_map(order_: str):
+     return tuple(_DEFAULT_ORDER.index(c) for c in order_.upper())
+
+
+ def rgb_encode(image: ImageTyping, order_: str = 'CHW', use_float: bool = True) -> np.ndarray:
+     """
+     Overview:
+         Encode image as RGB channels.
+
+     :param image: Image to be encoded.
+     :param order_: Order of encoding, default is ``CHW``.
+     :param use_float: Use float to represent the channels, default is ``True``. ``np.uint8`` will be used when ``False``.
+     :return: Encoded RGB image.
+
+     Examples::
+         >>> from PIL import Image
+         >>> from encode import rgb_encode
+         >>>
+         >>> image = Image.open('custom_image.jpg')
+         >>> image
+         <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1606x1870 at 0x7F9EC37389D0>
+         >>>
+         >>> data = rgb_encode(image)
+         >>> data.shape, data.dtype
+         ((3, 1870, 1606), dtype('float32'))
+         >>> data = rgb_encode(image, order_='CHW')
+         >>> data.shape, data.dtype
+         ((3, 1870, 1606), dtype('float32'))
+         >>> data = rgb_encode(image, order_='WHC')
+         >>> data.shape, data.dtype
+         ((1606, 1870, 3), dtype('float32'))
+         >>> data = rgb_encode(image, use_float=False)
+         >>> data.shape, data.dtype
+         ((3, 1870, 1606), dtype('uint8'))
+
+     .. note::
+         The result of :func:`rgb_encode` is the same as that of \
+         ``torchvision.transforms.functional.to_tensor`` when the given ``image`` is in RGB mode.
+     """
+     image = load_image(image, mode='RGB')
+     array = np.asarray(image)
+     array = np.transpose(array, _get_hwc_map(order_))
+     if use_float:
+         array = (array / 255.0).astype(np.float32)
+         assert array.dtype == np.float32
+     else:
+         assert array.dtype == np.uint8
+     return array
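Not part of the commit itself, but a minimal usage sketch of the new rgb_encode helper. It assumes encode.py is importable from the working directory and that 'custom_image.jpg' is a placeholder for any local image; the optional torchvision comparison only illustrates the equivalence claimed in the docstring note.

import numpy as np
from PIL import Image

from encode import rgb_encode

# Placeholder path; any local RGB image works here.
image = Image.open('custom_image.jpg').convert('RGB')

# Default CHW float encoding, channel values scaled into [0, 1].
chw = rgb_encode(image)
assert chw.shape == (3, image.height, image.width)
assert chw.dtype == np.float32

# Raw uint8 channels, no scaling.
raw = rgb_encode(image, use_float=False)
assert raw.dtype == np.uint8

# Optional cross-check of the docstring note; skipped when torchvision is absent.
try:
    from torchvision.transforms.functional import to_tensor
    assert np.allclose(chw, to_tensor(image).numpy(), atol=1e-6)
except ImportError:
    pass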
monochrome.py CHANGED
@@ -1,24 +1,25 @@
  from functools import lru_cache
- from typing import Optional, Mapping
+ from typing import Optional, Tuple

  import numpy as np
- from PIL import Image, ImageFilter
+ from PIL import Image
+ from PIL.Image import Resampling
  from huggingface_hub import hf_hub_download
- from scipy import signal

+ from encode import rgb_encode
  from image import ImageTyping, load_image
  from onnxruntime_ import open_onnx_model

  __all__ = [
      'get_monochrome_score',
+     'is_monochrome',
  ]

- _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
+ # _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
  _MONOCHROME_CKPTS = [
-     'monochrome-resnet18-safe2-450.onnx',
-     'monochrome-resnet18-480.onnx',
-     'monochrome-alexnet-480.onnx',
+     'monochrome-caformer_safe2-80.onnx',
  ]
+ _DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]


  @lru_cache()
@@ -29,49 +30,29 @@ def _monochrome_validate_model(ckpt):
      ))


- def np_hist(x, a_min: float = 0.0, a_max: float = 1.0, bins: int = 256):
-     x = np.asarray(x)
-     edges = np.linspace(a_min, a_max, bins + 1)
-     cnt, _ = np.histogram(x, bins=edges)
-     return cnt / cnt.sum()
+ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
+                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+     image = image.resize(size, Resampling.BILINEAR)
+     data = rgb_encode(image, order_='CHW')

+     if normalize is not None:
+         mean_, std_ = normalize
+         mean = np.asarray([mean_]).reshape((-1, 1, 1))
+         std = np.asarray([std_]).reshape((-1, 1, 1))
+         data = (data - mean) / std

- def butterworth_filter(r, fc):
-     w = fc / (len(r) / 2)  # Normalize the frequency
-     b, a = signal.butter(5, w, 'low')
-     return np.clip(signal.filtfilt(b, a, r), a_min=0.0, a_max=1.0)
+     return data


- def _hsv_encode(image: Image.Image, feature_bins: int = 180, mf: Optional[int] = 5,
-                 maxpixels: int = 20000, fc: Optional[int] = 75, normalize: bool = True):
-     if image.width * image.height > maxpixels:
-         r = (image.width * image.height / maxpixels) ** 0.5
-         new_width, new_height = map(lambda x: int(round(x / r)), image.size)
-         image = image.resize((new_width, new_height))
-
-     if mf is not None:
-         image = image.filter(ImageFilter.MedianFilter(mf))
-     image = image.convert('HSV')
-
-     data = (np.transpose(np.asarray(image), (2, 0, 1)) / 255.0).astype(np.float32)
-     channels = [np_hist(data[i], bins=feature_bins) for i in range(3)]
-     if fc is not None:
-         channels = [butterworth_filter(ch, fc) for ch in channels]
-
-     dist = np.stack(channels)
-     assert dist.shape == (3, feature_bins)
-
-     if normalize:
-         mean = np.mean(dist, axis=1, keepdims=True)
-         std = np.std(dist, axis=1, keepdims=True, ddof=1)
-         dist = (dist - mean) / std
-
-     return dist
-
-
- def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> Mapping[str, float]:
+ def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
      image = load_image(image, mode='RGB')
-     input_data = _hsv_encode(image).astype(np.float32)
+     input_data = _2d_encode(image).astype(np.float32)
      input_data = np.stack([input_data])
      output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data})
      return {name: v.item() for name, v in zip(['normal', 'monochrome'], output_data[0])}
+
+
+ def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
+     return get_monochrome_score(image, ckpt) >= threshold
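For orientation (again, not part of the diff), the sketch below spells out what the new _2d_encode preprocessing amounts to: a 384x384 bilinear resize, CHW float encoding via rgb_encode, then normalization with mean 0.5 and std 0.5. The torchvision pipeline is only an assumed equivalent for comparison, and the snippet presumes it runs alongside the repository's encode.py, image.py and onnxruntime_.py modules.

import numpy as np
from PIL import Image

from monochrome import _2d_encode

# Placeholder path; any local image can be used.
image = Image.open('custom_image.jpg')

data = _2d_encode(image)
# 384x384 bilinear resize, CHW channels in [0, 1], then (x - 0.5) / 0.5 -> [-1, 1].
assert data.shape == (3, 384, 384)

# Assumed-equivalent torchvision pipeline; only checked when torchvision is installed.
try:
    import torchvision.transforms as T
    reference = T.Compose([
        T.Resize((384, 384)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])(image.convert('RGB')).numpy()
    assert np.allclose(data, reference, atol=1e-4)
except ImportError:
    pass

With a checkpoint resolved through hf_hub_download, is_monochrome(image) then simply thresholds the score returned by get_monochrome_score at 0.5 by default.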