narugo1992 commited on
Commit
a0ff18d
·
1 Parent(s): da78f71

dev(narugo): use new models

Browse files
Files changed (1) hide show
  1. monochrome.py +9 -10
monochrome.py CHANGED
@@ -17,19 +17,18 @@ __all__ = [
17
 
18
  # _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
19
  _MONOCHROME_CKPTS = [
20
- 'monochrome-caformer-110.onnx',
21
- 'monochrome-caformer_safe2-80.onnx',
22
- 'monochrome-caformer_safe4-70.onnx',
23
- 'monochrome-caformer-40.onnx',
24
  ]
25
  _DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]
26
 
27
 
28
  @lru_cache()
29
- def _monochrome_validate_model(ckpt):
30
  return open_onnx_model(hf_hub_download(
31
- 'deepghs/imgutils-models',
32
- f'monochrome/{ckpt}'
33
  ))
34
 
35
 
@@ -49,12 +48,12 @@ def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
49
  return data
50
 
51
 
52
- def get_monochrome_score(image: ImageTyping, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> float:
53
  image = load_image(image, mode='RGB')
54
  input_data = _2d_encode(image).astype(np.float32)
55
  input_data = np.stack([input_data])
56
- output_data, = _monochrome_validate_model(ckpt).run(['output'], {'input': input_data})
57
- return {name: v.item() for name, v in zip(['normal', 'monochrome'], output_data[0])}
58
 
59
 
60
  def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool:
 
17
 
18
  # _DEFAULT_MONOCHROME_CKPT = 'monochrome-resnet18-safe2-450.onnx'
19
  _MONOCHROME_CKPTS = [
20
+ 'mobilenetv3_large_100_safe2',
21
+ 'mobilenetv3_large_100',
22
+ 'caformer_s36',
 
23
  ]
24
  _DEFAULT_MONOCHROME_CKPT = _MONOCHROME_CKPTS[0]
25
 
26
 
27
  @lru_cache()
28
+ def _monochrome_validate_model(model):
29
  return open_onnx_model(hf_hub_download(
30
+ f'deepghs/monochrome_detect',
31
+ f'{model}/model.onnx'
32
  ))
33
 
34
 
 
48
  return data
49
 
50
 
51
+ def get_monochrome_score(image: ImageTyping, model: str = _DEFAULT_MONOCHROME_CKPT):
52
  image = load_image(image, mode='RGB')
53
  input_data = _2d_encode(image).astype(np.float32)
54
  input_data = np.stack([input_data])
55
+ output_data, = _monochrome_validate_model(model).run(['output'], {'input': input_data})
56
+ return {name: v.item() for name, v in zip(['monochrome', 'normal'], output_data[0])}
57
 
58
 
59
  def is_monochrome(image: ImageTyping, threshold: float = 0.5, ckpt: str = _DEFAULT_MONOCHROME_CKPT) -> bool: