hysts HF Staff commited on
Commit
9553b1d
·
1 Parent(s): f8f6bc9
Files changed (2) hide show
  1. app.py +106 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import mim
4
+
5
+ mim.uninstall('mmcv-full', confirm_yes=True)
6
+ mim.install('mmcv-full', is_yes=True)
7
+
8
+ import functools
9
+ import pathlib
10
+
11
+ import cv2
12
+ import gradio as gr
13
+ import numpy as np
14
+ import PIL.Image
15
+ import torch
16
+
17
+ import anime_face_detector
18
+
19
+
20
+ def detect(img, face_score_threshold: float, landmark_score_threshold: float,
21
+ detector: anime_face_detector.LandmarkDetector) -> PIL.Image.Image:
22
+ image = cv2.imread(img.name)
23
+ preds = detector(image)
24
+
25
+ res = image.copy()
26
+ for pred in preds:
27
+ box = pred['bbox']
28
+ box, score = box[:4], box[4]
29
+ if score < face_score_threshold:
30
+ continue
31
+ box = np.round(box).astype(int)
32
+
33
+ lt = max(2, int(3 * (box[2:] - box[:2]).max() / 256))
34
+
35
+ cv2.rectangle(res, tuple(box[:2]), tuple(box[2:]), (0, 255, 0), lt)
36
+
37
+ pred_pts = pred['keypoints']
38
+ for *pt, score in pred_pts:
39
+ if score < landmark_score_threshold:
40
+ color = (0, 255, 255)
41
+ else:
42
+ color = (0, 0, 255)
43
+ pt = np.round(pt).astype(int)
44
+ cv2.circle(res, tuple(pt), lt, color, cv2.FILLED)
45
+ res = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)
46
+
47
+ image_pil = PIL.Image.fromarray(res)
48
+ return image_pil
49
+
50
+
51
+ def main():
52
+ sample_path = pathlib.Path('input.jpg')
53
+ if not sample_path.exists():
54
+ torch.hub.download_url_to_file(
55
+ 'https://raw.githubusercontent.com/hysts/anime-face-detector/main/assets/input.jpg',
56
+ sample_path.as_posix())
57
+
58
+ detector_name = 'yolov3'
59
+ device = 'cpu'
60
+ score_slider_step = 0.05
61
+ face_score_threshold = 0.5
62
+ landmark_score_threshold = 0.3
63
+ live = False
64
+
65
+ detector = anime_face_detector.create_detector(detector_name,
66
+ device=device)
67
+ func = functools.partial(detect, detector=detector)
68
+ func = functools.update_wrapper(func, detect)
69
+
70
+ title = 'hysts/anime-face-detector'
71
+ description = 'Demo for hysts/anime-face-detector. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.'
72
+ article = "<a href='https://github.com/hysts/anime-face-detector'>GitHub Repo</a>"
73
+
74
+ gr.Interface(
75
+ func,
76
+ [
77
+ gr.inputs.Image(type='file', label='Input'),
78
+ gr.inputs.Slider(0,
79
+ 1,
80
+ step=score_slider_step,
81
+ default=face_score_threshold,
82
+ label='Face Score Threshold'),
83
+ gr.inputs.Slider(0,
84
+ 1,
85
+ step=score_slider_step,
86
+ default=landmark_score_threshold,
87
+ label='Landmark Score Threshold'),
88
+ ],
89
+ gr.outputs.Image(type='pil', label='Output'),
90
+ title=title,
91
+ description=description,
92
+ article=article,
93
+ examples=[
94
+ [
95
+ sample_path.as_posix(),
96
+ face_score_threshold,
97
+ landmark_score_threshold,
98
+ ],
99
+ ],
100
+ enable_queue=True,
101
+ live=live,
102
+ ).launch()
103
+
104
+
105
+ if __name__ == '__main__':
106
+ main()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ anime-face-detector>=0.0.4
2
+ mmcv-full>=1.3.16
3
+ mmdet>=2.18.0
4
+ mmpose>=0.20.0
5
+ numpy>=1.21.3
6
+ opencv-python-headless>=4.5.4.58
7
+ openmim>=0.1.5
8
+ torch>=1.10.0
9
+ torchvision>=0.11.1