thng292 committed
Commit d0ac7e9 · verified · 1 Parent(s): 5bbf229

Upload 18 files

.Dockerignore ADDED
@@ -0,0 +1,2 @@
+ facexformer
+ ckpts
.gitignore ADDED
@@ -0,0 +1,7 @@
+ # Exclude model's weights
+ *.pt
+ .venv/
+ .vscode/
+ **/__pycache__
+ data
+ saves
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Kartik Narayan
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,171 @@
+ # import streamlit as st
+ # from streamlit_webrtc import webrtc_streamer
+ # import torch
+ # torch.classes.__path__ = []
+
+ import sys
+ import os
+ from glob import glob
+ import gradio as gr
+ from fastrtc import WebRTC
+ from fastrtc import VideoStreamHandler
+ from PIL import Image
+ import landmark_detection
+ import numpy as np
+ from time import time
+ import cv2
+ from mtcnn_facedetection import detect_faces
+ from selfie_filter import apply_sunglasses, process_video
+
+
+ radius = 2
+ filter_img = None
+
+
+ def do_facial_landmark_recognition(
+     image: np.ndarray, face_boxes: list[landmark_detection.BoundingBox]
+ ):
+     faces = landmark_detection.get_faces(image, face_boxes)
+     landmarks_batch = landmark_detection.get_landmarks(faces)
+
+     # Draw every detected landmark as a small dot
+     for landmarks in landmarks_batch:
+         for landmark in landmarks:
+             image = cv2.circle(image, landmark, radius, (255, 0, 0), -1)
+
+     return image, landmarks_batch
+
+
+ def do_facial_landmark_recognition_with_mtcnn(image: np.ndarray):
+     face_boxes = detect_faces(image)
+     return do_facial_landmark_recognition(image, face_boxes)
+
+
+ def video_frame_callback_gradio(frame: np.ndarray):
+     flipped = cv2.flip(frame, 1)
+
+     flipped, landmarks_batch = do_facial_landmark_recognition_with_mtcnn(flipped)
+     # Apply sunglasses filter
+     image = apply_sunglasses(flipped, landmarks_batch, filter_img)
+
+     return image  # , AdditionalOutputs(flipped, flipped)
+
+
+ css = """.my-group {max-width: 600px !important;}
+ .my-column {display: flex !important; justify-content: center !important; align-items: center !important;}"""
+
+ image_extensions = [
+     "*.jpg",
+     "*.jpeg",
+     "*.png",
+     "*.gif",
+     "*.bmp",
+     "*.tiff",
+     "*.webp",
+ ]
+ all_image_files = []
+
+ for ext in image_extensions:
+     pattern = os.path.join("images", "**", ext)  # '**' for recursive search
+     image_files = glob(pattern, recursive=True)
+     all_image_files.extend(image_files)
+ all_image_files.sort()
+
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_classes=["my-column"]):
+         gr.HTML(
+             """
+             <h1 style='text-align: center'>
+             Live Filter with FaceXFormer
+             </h1>
+             """
+         )
+         with gr.Group(elem_classes=["my-group"]):
+             selected_filter = gr.Dropdown(
+                 choices=all_image_files,
+                 label="Choose filter",
+                 value="images/sunglasses_1.png",
+             )
+
+             def change_filter(filter_path):
+                 global filter_img
+                 filter_img = cv2.imread(filter_path, cv2.IMREAD_UNCHANGED)
+                 if filter_img is None:
+                     raise gr.Error(f"Error opening {filter_path}")
+
+             change_filter(selected_filter.value)
+
+             selected_filter.change(
+                 change_filter, inputs=[selected_filter], show_progress="full"
+             )
+
+         with gr.Group(elem_classes=["my-group"]):
+             stream = WebRTC(label="Stream", rtc_configuration=None)
+             stream.stream(
+                 fn=VideoStreamHandler(
+                     video_frame_callback_gradio, fps=12, skip_frames=True
+                 ),
+                 inputs=[stream],
+                 outputs=[stream],
+                 time_limit=None,
+             )
+
+     with gr.Group(elem_classes=["my-group"]):
+         with gr.Column(elem_classes=["my-column"]):
+             gr.HTML(
+                 """
+                 <h1 style='text-align: center'>
+                 Or just apply the filter to a video
+                 </h1>
+                 """
+             )
+             input_video = gr.Video(sources="upload", include_audio=False)
+             output_video = gr.Video(interactive=False, include_audio=False)
+             submit = gr.Button(variant="primary")
+         with gr.Column(elem_classes=["my-column"]):
+             submit.click(
+                 lambda input_path: process_video(input_path, filter_img),
+                 inputs=[input_video],
+                 outputs=[output_video],
+                 show_progress="full",
+             )
+
+
+ def test(times=10):
+     image = np.array(Image.open("tmp.jpg").resize((512, 512)))
+     # faces = ai.get_faces(image)
+     start = time()
+     frame_times = [None] * times
+     for i in range(times):
+         before = time()
+         do_facial_landmark_recognition_with_mtcnn(image)
+         after = time()
+         frame_times[i] = after - before
+     end = time()
+
+     print(f"Num Images: {times}")
+     print(f"Total time: {end - start}")
+     print(
+         f"Max frametime: {max(frame_times)}, FPS: {1 / max(frame_times)}",
+     )
+     print(
+         f"Min frametime: {min(frame_times)}, FPS: {1 / min(frame_times)}",
+     )
+     print(
+         f"Avg frametime: {sum(frame_times) / len(frame_times)}, FPS: {1 / (sum(frame_times) / len(frame_times))}",
+     )
+
+
+ if __name__ == "__main__":
+     no_params = 0
+     for name, i in landmark_detection.model.named_parameters(recurse=True):
+         no_params += i.numel()
+         print(name, i.numel())
+
+     print(no_params)
+     if "--test" in sys.argv:
+         test()
+         exit(0)
+     else:
+         demo.launch()
images/sunglasses_1.png ADDED
images/sunglasses_2.png ADDED
images/sunglasses_3.jpg ADDED
images/sunglasses_4.png ADDED
images/sunglasses_5.jpg ADDED
images/sunglasses_6.png ADDED
landmark_detection.py ADDED
@@ -0,0 +1,177 @@
+ import torch
+ import torchvision
+ from torchvision.transforms import InterpolationMode
+ from network.models.facexformer import FaceXFormer
+ from dataclasses import dataclass
+ import numpy as np
+
+ # import mediapipe as mp
+ # import cv2
+
+
+ # device = "cuda:0"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float32
+ # weights_path = "ckpts/model.pt"
+ weights_path = "ckpts/pytorch_model.bin"
+ # face_model_path = "ckpts/blaze_face_short_range.tflite"
+
+ # import mediapipe as mp
+
+ # BaseOptions = mp.tasks.BaseOptions
+ # FaceDetector = mp.tasks.vision.FaceDetector
+ # FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
+ # FaceDetectorResult = mp.tasks.vision.FaceDetectorResult
+ # VisionRunningMode = mp.tasks.vision.RunningMode
+
+ # options = FaceDetectorOptions(
+ #     base_options=BaseOptions(model_asset_path=face_model_path),
+ #     running_mode=VisionRunningMode.LIVE_STREAM,
+ # )
+ # face_detector = FaceDetector.create_from_options(options)
+
+ transforms_image = torchvision.transforms.Compose(
+     [
+         torchvision.transforms.ToPILImage(),
+         torchvision.transforms.Resize(
+             size=(224, 224), interpolation=InterpolationMode.BICUBIC
+         ),
+         torchvision.transforms.ToTensor(),
+         torchvision.transforms.Normalize(
+             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+         ),
+     ]
+ )
+
+
+ def load_model(weights_path):
+     model = FaceXFormer().to(device)
+     checkpoint = torch.load(weights_path, map_location=device)
+     model.load_state_dict(checkpoint)
+     # model.load_state_dict(checkpoint["state_dict_backbone"])
+     model = model.eval()
+     model = model.to(dtype=dtype)
+     # model = torch.compile(model, mode="reduce-overhead")
+     return model
+
+
+ model = load_model(weights_path)
+
+
+ def adjust_bbox(
+     x_min, y_min, x_max, y_max, image_width, image_height, margin_percentage=50
+ ):
+     width = x_max - x_min
+     height = y_max - y_min
+
+     increase_width = width * (margin_percentage / 100.0) / 2
+     increase_height = height * (margin_percentage / 100.0) / 2
+
+     x_min_adjusted = int(max(0, x_min - increase_width))
+     y_min_adjusted = int(max(0, y_min - increase_height))
+     x_max_adjusted = int(min(image_width, x_max + increase_width))
+     y_max_adjusted = int(min(image_height, y_max + increase_height))
+
+     return x_min_adjusted, y_min_adjusted, x_max_adjusted, y_max_adjusted
+
+
+ def denorm_points(points, h, w, align_corners=False):
+     if align_corners:
+         denorm_points = (
+             (points + 1) / 2 * torch.tensor([w - 1, h - 1]).to(points).view(1, 1, 2)
+         )
+     else:
+         denorm_points = (
+             (points + 1) * torch.tensor([w, h]).to(points).view(1, 1, 2) - 1
+         ) / 2
+
+     return denorm_points
+
+
+ @dataclass
+ class BoundingBox:
+     x_min: int
+     y_min: int
+     x_max: int
+     y_max: int
+
+
+ @dataclass
+ class FaceImg:
+     image: np.ndarray
+     x_min: int
+     y_min: int
+
+
+ def get_faces_img(img: np.ndarray, boxes: list[BoundingBox]):
+     if boxes is None or len(boxes) == 0:
+         return []
+     results: list[FaceImg] = []
+     for box in boxes:
+         x_min, y_min, x_max, y_max = box.x_min, box.y_min, box.x_max, box.y_max
+
+         # Padding
+         x_min, y_min, x_max, y_max = adjust_bbox(
+             x_min, y_min, x_max, y_max, img.shape[1], img.shape[0]
+         )
+         image = img[y_min:y_max, x_min:x_max]
+         results.append(FaceImg(image, int(x_min), int(y_min)))
+
+     return results
+
+
+ @dataclass
+ class Face:
+     image: torch.Tensor
+     x_min: int
+     y_min: int
+     original_w: int
+     original_h: int
+
+
+ def get_faces(img: np.ndarray, boxes: list[BoundingBox]):
+     images = get_faces_img(img, boxes)
+     images = [
+         Face(
+             transforms_image(face_image.image),
+             face_image.x_min,
+             face_image.y_min,
+             face_image.image.shape[1],
+             face_image.image.shape[0],
+         )
+         for face_image in images
+     ]
+     return images
+
+
+ def get_landmarks(faces: list[Face]):
+     if len(faces) == 0:
+         return []
+
+     images = torch.stack([face.image for face in faces]).to(device=device, dtype=dtype)
+
+     tasks = torch.tensor([1] * len(faces), device=device, dtype=dtype)
+     with torch.inference_mode():
+         # with torch.amp.autocast("cuda"):
+         (
+             batch_landmarks,
+             headposes,
+             attributes,
+             visibilities,
+             ages,
+             genders,
+             races,
+             segs,
+         ) = model.predict(images, None, tasks)
+         batch_denormed = [
+             denorm_points(landmarks, face.original_h, face.original_w)[0]
+             for landmarks, face in zip(batch_landmarks.view(-1, 68, 2), faces)
+         ]
+
+     results = []
+     for landmarks, face in zip(batch_denormed, faces):
+         results.append(
+             [(int(x + face.x_min), int(y + face.y_min)) for x, y in landmarks]
+         )
+
+     return results
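Note (not part of the commit): a minimal sketch of the align_corners=False mapping that denorm_points applies in get_landmarks above; the 200x100 crop size is an arbitrary example.

# Hedged sketch: x_px = ((x_norm + 1) * w - 1) / 2, y_px = ((y_norm + 1) * h - 1) / 2
import torch

points = torch.tensor([[[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]]])  # normalized (x, y) in [-1, 1]
h, w = 100, 200                                                  # example crop height / width
print(((points + 1) * torch.tensor([w, h]).view(1, 1, 2) - 1) / 2)
# -> [[-0.5, -0.5], [99.5, 49.5], [199.5, 99.5]] pixel coordinates inside the crop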
mediapipe_facedetection.py ADDED
File without changes
mtcnn_facedetection.py ADDED
@@ -0,0 +1,18 @@
+ from landmark_detection import device, BoundingBox
+ from facenet_pytorch import MTCNN
+ import numpy as np
+
+ mtcnn = MTCNN(keep_all=True, device=device).eval()
+
+
+ def detect_faces(img) -> list[BoundingBox]:
+     boxes, probs = mtcnn.detect(img)
+     return [
+         BoundingBox(
+             x_min=int(box[0]),
+             y_min=int(box[1]),
+             x_max=int(box[2]),
+             y_max=int(box[3]),
+         )
+         for box in boxes
+     ] if boxes is not None else []
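Note (not part of the commit): a minimal end-to-end sketch tying the two modules above together; it assumes the checkpoint at ckpts/pytorch_model.bin is present (importing landmark_detection loads the model) and that tmp.jpg is any RGB test image.

# Hedged usage sketch: MTCNN boxes -> FaceXFormer landmarks on a single image.
import numpy as np
from PIL import Image

import landmark_detection
from mtcnn_facedetection import detect_faces

img = np.array(Image.open("tmp.jpg").convert("RGB"))   # H x W x 3 RGB frame
boxes = detect_faces(img)                               # list[BoundingBox]
faces = landmark_detection.get_faces(img, boxes)        # padded crops, resized to 224x224
landmarks = landmark_detection.get_landmarks(faces)     # 68 (x, y) points per face, image coords
print(f"{len(landmarks)} face(s) detected")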
network/__init__.py ADDED
@@ -0,0 +1 @@
+ from .models import FaceXFormer
network/models/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .transformer import TwoWayTransformer, LayerNorm2d
+ from .facexformer import FaceXFormer
network/models/facexformer.py ADDED
@@ -0,0 +1,392 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from typing import Any, Optional, Tuple, Type
+ from torchvision.models import swin_b, convnext_base
+ from .transformer import TwoWayTransformer, LayerNorm2d
+ from transformers.utils.generic import ModelOutput
+
+
+ class MLP(nn.Module):
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         output_dim: int,
+         num_layers: int,
+         sigmoid_output: bool = False,
+     ) -> None:
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(
+             nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+         )
+         self.sigmoid_output = sigmoid_output
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         if self.sigmoid_output:
+             x = F.sigmoid(x)
+         return x
+
+
+ class FaceDecoder(nn.Module):
+     def __init__(
+         self,
+         *,
+         transformer_dim: int,
+         transformer: nn.Module,
+         activation: Type[nn.Module] = nn.GELU,
+     ) -> None:
+
+         super().__init__()
+         self.transformer_dim = transformer_dim
+         self.transformer = transformer
+
+         self.landmarks_token = nn.Embedding(1, transformer_dim)
+         self.pose_token = nn.Embedding(1, transformer_dim)
+         self.attribute_token = nn.Embedding(1, transformer_dim)
+         self.visibility_token = nn.Embedding(1, transformer_dim)
+         self.age_token = nn.Embedding(1, transformer_dim)
+         self.gender_token = nn.Embedding(1, transformer_dim)
+         self.race_token = nn.Embedding(1, transformer_dim)
+         self.mask_tokens = nn.Embedding(11, transformer_dim)
+
+         self.output_upscaling = nn.Sequential(
+             nn.ConvTranspose2d(
+                 transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+             ),
+             LayerNorm2d(transformer_dim // 4),
+             activation(),
+             nn.ConvTranspose2d(
+                 transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+             ),
+             activation(),
+         )
+
+         self.output_hypernetwork_mlps = MLP(
+             transformer_dim, transformer_dim, transformer_dim // 8, 3
+         )
+
+         self.landmarks_prediction_head = MLP(transformer_dim, transformer_dim, 136, 3)
+         self.pose_prediction_head = MLP(transformer_dim, transformer_dim, 3, 3)
+         self.attribute_prediction_head = MLP(transformer_dim, transformer_dim, 40, 3)
+         self.visibility_prediction_head = MLP(transformer_dim, transformer_dim, 29, 3)
+         self.age_prediction_head = MLP(transformer_dim, transformer_dim, 8, 3)
+         self.gender_prediction_head = MLP(transformer_dim, transformer_dim, 2, 3)
+         self.race_prediction_head = MLP(transformer_dim, transformer_dim, 5, 3)
+
+     def forward(
+         self,
+         image_embeddings: torch.Tensor,
+         image_pe: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         output_tokens = torch.cat(
+             [
+                 self.landmarks_token.weight,
+                 self.pose_token.weight,
+                 self.attribute_token.weight,
+                 self.visibility_token.weight,
+                 self.age_token.weight,
+                 self.gender_token.weight,
+                 self.race_token.weight,
+                 self.mask_tokens.weight,
+             ],
+             dim=0,
+         )
+         tokens = output_tokens.unsqueeze(0).expand(image_embeddings.size(0), -1, -1)
+
+         src = image_embeddings
+         pos_src = image_pe.expand(image_embeddings.size(0), -1, -1, -1)
+         b, c, h, w = src.shape
+
+         hs, src = self.transformer(src, pos_src, tokens)
+
+         landmarks_token_out = hs[:, 0, :]
+         pose_token_out = hs[:, 1, :]
+         attribute_token_out = hs[:, 2, :]
+         visibility_token_out = hs[:, 3, :]
+         age_token_out = hs[:, 4, :]
+         gender_token_out = hs[:, 5, :]
+         race_token_out = hs[:, 6, :]
+         mask_token_out = hs[:, 7:, :]
+
+         landmark_output = self.landmarks_prediction_head(landmarks_token_out)
+         headpose_output = self.pose_prediction_head(pose_token_out)
+         attribute_output = self.attribute_prediction_head(attribute_token_out)
+         visibility_output = self.visibility_prediction_head(visibility_token_out)
+         age_output = self.age_prediction_head(age_token_out)
+         gender_output = self.gender_prediction_head(gender_token_out)
+         race_output = self.race_prediction_head(race_token_out)
+
+         src = src.transpose(1, 2).view(b, c, h, w)
+         upscaled_embedding = self.output_upscaling(src)
+         hyper_in = self.output_hypernetwork_mlps(mask_token_out)
+         b, c, h, w = upscaled_embedding.shape
+         seg_output = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+         return (
+             landmark_output,
+             headpose_output,
+             attribute_output,
+             visibility_output,
+             age_output,
+             gender_output,
+             race_output,
+             seg_output,
+         )
+
+
+ class PositionEmbeddingRandom(nn.Module):
+     """
+     Positional encoding using random spatial frequencies.
+     """
+
+     def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+         super().__init__()
+         if scale is None or scale <= 0.0:
+             scale = 1.0
+         self.register_buffer(
+             "positional_encoding_gaussian_matrix",
+             scale * torch.randn((2, num_pos_feats)),
+         )
+
+     def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+         """Positionally encode points that are normalized to [0,1]."""
+         # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+         coords = 2 * coords - 1
+         coords = coords @ self.positional_encoding_gaussian_matrix
+         coords = 2 * np.pi * coords
+         # outputs d_1 x ... x d_n x C shape
+         return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+     def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+         """Generate positional encoding for a grid of the specified size."""
+         h, w = size
+         device: Any = self.positional_encoding_gaussian_matrix.device
+         grid = torch.ones((h, w), device=device, dtype=torch.float32)
+         y_embed = grid.cumsum(dim=0) - 0.5
+         x_embed = grid.cumsum(dim=1) - 0.5
+         y_embed = y_embed / h
+         x_embed = x_embed / w
+
+         pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+         return pe.permute(2, 0, 1)  # C x H x W
+
+     def forward_with_coords(
+         self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+     ) -> torch.Tensor:
+         """Positionally encode points that are not normalized to [0,1]."""
+         coords = coords_input.clone()
+         coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+         coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+         return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+ class FaceXFormerMLP(nn.Module):
+     def __init__(self, input_dim):
+         super().__init__()
+         self.proj = nn.Linear(input_dim, 256)  # 128, 256, 512, 1024 => 256
+
+     def forward(self, hidden_states: torch.Tensor):
+         hidden_states = hidden_states.flatten(2).transpose(1, 2)
+         hidden_states = self.proj(hidden_states)
+         return hidden_states
+
+
+ class FaceXFormer(nn.Module):
+     def __init__(self):
+         super(FaceXFormer, self).__init__()
+
+         # Backbone: Swin-B
+         swin_v2 = swin_b(weights="IMAGENET1K_V1")
+         self.backbone = torch.nn.Sequential(*(list(swin_v2.children())[:-1]))
+         self.backbone.requires_grad_(False)
+
+         # # Backbone: ConvNext-B
+         # convnext_v2 = convnext_base(weights='IMAGENET1K_V1')
+         # self.backbone = torch.nn.Sequential(
+         #     *(list(convnext_v2.children())[:-1]))
+
+         self.target_layer_names = ["0.1", "0.3", "0.5", "0.7"]
+         self.multi_scale_features = []
+
+         embed_dim = 1024
+         out_chans = 256
+
+         self.pe_layer = PositionEmbeddingRandom(out_chans // 2)
+
+         for name, module in self.backbone.named_modules():
+             if name in self.target_layer_names:
+                 module.register_forward_hook(self.save_features_hook(name))
+
+         self.face_decoder = FaceDecoder(
+             transformer_dim=256,
+             transformer=TwoWayTransformer(
+                 depth=2,
+                 embedding_dim=256,
+                 mlp_dim=2048,
+                 num_heads=8,
+             ),
+         )
+
+         num_encoder_blocks = 4
+         hidden_sizes = [128, 256, 512, 1024]
+         decoder_hidden_size = 256
+
+         mlps = []
+         for i in range(num_encoder_blocks):
+             mlp = FaceXFormerMLP(input_dim=hidden_sizes[i])
+             mlps.append(mlp)
+         self.linear_c = nn.ModuleList(mlps)
+
+         self.linear_fuse = nn.Conv2d(
+             in_channels=decoder_hidden_size * num_encoder_blocks,  # 1024
+             out_channels=decoder_hidden_size,  # 256
+             kernel_size=1,
+             bias=False,
+         )
+
+     def save_features_hook(self, name):
+         def hook(module, input, output):
+             self.multi_scale_features.append(output.permute(0, 3, 1, 2).contiguous())
+
+         return hook
+
+     def predict(self, x, labels, tasks):
+         self.multi_scale_features.clear()
+
+         _, _, h, w = x.shape
+         features = self.backbone(x).squeeze()
+
+         batch_size = self.multi_scale_features[-1].shape[0]
+         all_hidden_states = ()
+         for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
+
+             height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
+             encoder_hidden_state = mlp(encoder_hidden_state)
+             encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
+             encoder_hidden_state = encoder_hidden_state.reshape(
+                 batch_size, -1, height, width
+             )
+             encoder_hidden_state = nn.functional.interpolate(
+                 encoder_hidden_state,
+                 size=self.multi_scale_features[0].size()[2:],
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             all_hidden_states += (encoder_hidden_state,)
+
+         fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
+         image_pe = self.pe_layer(
+             (fused_states.shape[2], fused_states.shape[3])
+         ).unsqueeze(0)
+
+         (
+             landmark_output,
+             headpose_output,
+             attribute_output,
+             visibility_output,
+             age_output,
+             gender_output,
+             race_output,
+             seg_output,
+         ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
+
+         segmentation_indices = tasks == 0
+         seg_output = seg_output[segmentation_indices]
+
+         landmarks_indices = tasks == 1
+         landmark_output = landmark_output[landmarks_indices]
+
+         headpose_indices = tasks == 2
+         headpose_output = headpose_output[headpose_indices]
+
+         attribute_indices = tasks == 3
+         attribute_output = attribute_output[attribute_indices]
+
+         age_indices = tasks == 4
+         age_output = age_output[age_indices]
+         gender_output = gender_output[age_indices]
+         race_output = race_output[age_indices]
+
+         visibility_indices = tasks == 5
+         visibility_output = visibility_output[visibility_indices]
+
+         return (
+             landmark_output,
+             headpose_output,
+             attribute_output,
+             visibility_output,
+             age_output,
+             gender_output,
+             race_output,
+             seg_output,
+         )
+
+     def loss(
+         self, predictions: torch.Tensor, labels: torch.Tensor, num_items_in_batch=None
+     ):
+         # print(predictions.shape)
+         # print(labels.shape)
+         # print("predic:", predictions)
+         # print("labels:", labels)
+         # Used L2 loss for now
+         loss = torch.nn.functional.mse_loss(predictions, labels, reduction="sum")
+         if num_items_in_batch:
+             loss /= num_items_in_batch
+         return loss
+
+     def forward(self, x, labels, num_items_in_batch=None):
+         self.multi_scale_features.clear()
+
+         _, _, h, w = x.shape
+         features = self.backbone(x).squeeze()
+
+         batch_size = self.multi_scale_features[-1].shape[0]
+         all_hidden_states = ()
+         for encoder_hidden_state, mlp in zip(self.multi_scale_features, self.linear_c):
+
+             height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
+             encoder_hidden_state = mlp(encoder_hidden_state)
+             encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
+             encoder_hidden_state = encoder_hidden_state.reshape(
+                 batch_size, -1, height, width
+             )
+             encoder_hidden_state = nn.functional.interpolate(
+                 encoder_hidden_state,
+                 size=self.multi_scale_features[0].size()[2:],
+                 mode="bilinear",
+                 align_corners=False,
+             )
+             all_hidden_states += (encoder_hidden_state,)
+
+         fused_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
+         image_pe = self.pe_layer(
+             (fused_states.shape[2], fused_states.shape[3])
+         ).unsqueeze(0)
+
+         (
+             landmark_output,
+             headpose_output,
+             attribute_output,
+             visibility_output,
+             age_output,
+             gender_output,
+             race_output,
+             seg_output,
+         ) = self.face_decoder(image_embeddings=fused_states, image_pe=image_pe)
+
+         # All tasks are landmark prediction
+         if labels is not None:
+             loss = self.loss(landmark_output.view(-1, 68, 2), labels)
+         else:
+             loss = None
+
+         return ModelOutput(
+             loss=loss,
+         )
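Note (not part of the commit): a small shape-check sketch for FaceXFormer.predict; it assumes randomly initialized heads (the torchvision Swin-B ImageNet weights are still downloaded on first use) and uses task id 1, which selects the landmark head exactly as landmark_detection.py does.

# Hedged sketch: expected input/output shapes of FaceXFormer.predict.
import torch
from network.models.facexformer import FaceXFormer

model = FaceXFormer().eval()
images = torch.randn(2, 3, 224, 224)   # batch of normalized 224x224 face crops
tasks = torch.tensor([1, 1])            # task 1 == facial landmarks
with torch.inference_mode():
    landmarks, *rest = model.predict(images, None, tasks)
print(landmarks.shape)                  # torch.Size([2, 136]) -> .view(-1, 68, 2) for (x, y) pairs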
network/models/transformer.py ADDED
@@ -0,0 +1,271 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from torch import Tensor, nn
+
+ import math
+ from typing import Tuple, Type
+
+
+ class MLPBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         mlp_dim: int,
+         act: Type[nn.Module] = nn.GELU,
+     ) -> None:
+         super().__init__()
+         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+         self.act = act()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lin2(self.act(self.lin1(x)))
+
+
+ # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+ # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
+
+
+ class TwoWayTransformer(nn.Module):
+     def __init__(
+         self,
+         depth: int,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+     ) -> None:
+         """
+         A transformer decoder that attends to an input image using
+         queries whose positional embedding is supplied.
+
+         Args:
+           depth (int): number of layers in the transformer
+           embedding_dim (int): the channel dimension for the input embeddings
+           num_heads (int): the number of heads for multihead attention. Must
+             divide embedding_dim
+           mlp_dim (int): the channel dimension internal to the MLP block
+           activation (nn.Module): the activation to use in the MLP block
+         """
+         super().__init__()
+         self.depth = depth
+         self.embedding_dim = embedding_dim
+         self.num_heads = num_heads
+         self.mlp_dim = mlp_dim
+         self.layers = nn.ModuleList()
+
+         for i in range(depth):
+             self.layers.append(
+                 TwoWayAttentionBlock(
+                     embedding_dim=embedding_dim,
+                     num_heads=num_heads,
+                     mlp_dim=mlp_dim,
+                     activation=activation,
+                     attention_downsample_rate=attention_downsample_rate,
+                     skip_first_layer_pe=(i == 0),
+                 )
+             )
+
+         self.final_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+     def forward(
+         self,
+         image_embedding: Tensor,
+         image_pe: Tensor,
+         point_embedding: Tensor,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+           image_embedding (torch.Tensor): image to attend to. Should be shape
+             B x embedding_dim x h x w for any h and w.
+           image_pe (torch.Tensor): the positional encoding to add to the image. Must
+             have the same shape as image_embedding.
+           point_embedding (torch.Tensor): the embedding to add to the query points.
+             Must have shape B x N_points x embedding_dim for any N_points.
+
+         Returns:
+           torch.Tensor: the processed point_embedding
+           torch.Tensor: the processed image_embedding
+         """
+         # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+         bs, c, h, w = image_embedding.shape
+         image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+         image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+         # Prepare queries
+         queries = point_embedding
+         keys = image_embedding
+
+         # Apply transformer blocks and final layernorm
+         for layer in self.layers:
+             queries, keys = layer(
+                 queries=queries,
+                 keys=keys,
+                 query_pe=point_embedding,
+                 key_pe=image_pe,
+             )
+
+         # Apply the final attention layer from the points to the image
+         q = queries + point_embedding
+         k = keys + image_pe
+         attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm_final_attn(queries)
+
+         return queries, keys
+
+
+ class TwoWayAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         mlp_dim: int = 2048,
+         activation: Type[nn.Module] = nn.ReLU,
+         attention_downsample_rate: int = 2,
+         skip_first_layer_pe: bool = False,
+     ) -> None:
+         """
+         A transformer block with four layers: (1) self-attention of sparse
+         inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+         block on sparse inputs, and (4) cross attention of dense inputs to sparse
+         inputs.
+
+         Arguments:
+           embedding_dim (int): the channel dimension of the embeddings
+           num_heads (int): the number of heads in the attention layers
+           mlp_dim (int): the hidden dimension of the mlp block
+           activation (nn.Module): the activation of the mlp block
+           skip_first_layer_pe (bool): skip the PE on the first layer
+         """
+         super().__init__()
+         self.self_attn = Attention(embedding_dim, num_heads)
+         self.norm1 = nn.LayerNorm(embedding_dim)
+
+         self.cross_attn_token_to_image = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+         self.norm2 = nn.LayerNorm(embedding_dim)
+
+         self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+         self.norm3 = nn.LayerNorm(embedding_dim)
+
+         self.norm4 = nn.LayerNorm(embedding_dim)
+         self.cross_attn_image_to_token = Attention(
+             embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+         )
+
+         self.skip_first_layer_pe = skip_first_layer_pe
+
+     def forward(
+         self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+     ) -> Tuple[Tensor, Tensor]:
+         # Self attention block
+         if self.skip_first_layer_pe:
+             queries = self.self_attn(q=queries, k=queries, v=queries)
+         else:
+             q = queries + query_pe
+             attn_out = self.self_attn(q=q, k=q, v=queries)
+             queries = queries + attn_out
+         queries = self.norm1(queries)
+
+         # Cross attention block, tokens attending to image embedding
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+         queries = queries + attn_out
+         queries = self.norm2(queries)
+
+         # MLP block
+         mlp_out = self.mlp(queries)
+         queries = queries + mlp_out
+         queries = self.norm3(queries)
+
+         # Cross attention block, image embedding attending to tokens
+         q = queries + query_pe
+         k = keys + key_pe
+         attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+         keys = keys + attn_out
+         keys = self.norm4(keys)
+
+         return queries, keys
+
+
+ class Attention(nn.Module):
+     """
+     An attention layer that allows for downscaling the size of the embedding
+     after projection to queries, keys, and values.
+     """
+
+     def __init__(
+         self,
+         embedding_dim: int,
+         num_heads: int,
+         downsample_rate: int = 1,
+     ) -> None:
+         super().__init__()
+         self.embedding_dim = embedding_dim
+         self.internal_dim = embedding_dim // downsample_rate
+         self.num_heads = num_heads
+         assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+         self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+         self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+     def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+         b, n, c = x.shape
+         x = x.reshape(b, n, num_heads, c // num_heads)
+         return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+     def _recombine_heads(self, x: Tensor) -> Tensor:
+         b, n_heads, n_tokens, c_per_head = x.shape
+         x = x.transpose(1, 2)
+         return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+         # Input projections
+         q = self.q_proj(q)
+         k = self.k_proj(k)
+         v = self.v_proj(v)
+
+         # Separate into heads
+         q = self._separate_heads(q, self.num_heads)
+         k = self._separate_heads(k, self.num_heads)
+         v = self._separate_heads(v, self.num_heads)
+
+         # Attention
+         _, _, _, c_per_head = q.shape
+         attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+         attn = attn / math.sqrt(c_per_head)
+         attn = torch.softmax(attn, dim=-1)
+
+         # Get output
+         out = attn @ v
+         out = self._recombine_heads(out)
+         out = self.out_proj(out)
+
+         return out
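Note (not part of the commit): a sketch of the tensor shapes TwoWayTransformer expects, matching how FaceDecoder drives it (7 task tokens + 11 mask tokens = 18 queries); the 56x56 feature grid is an arbitrary example.

# Hedged sketch: TwoWayTransformer I/O shapes (SAM-style two-way decoder).
import torch
from network.models.transformer import TwoWayTransformer

tfm = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 56, 56)   # B x C x H x W image features
image_pe = torch.randn(1, 256, 56, 56)          # positional encoding, same shape
tokens = torch.randn(1, 18, 256)                # B x N_tokens x C query tokens
queries, keys = tfm(image_embedding, image_pe, tokens)
print(queries.shape, keys.shape)                 # (1, 18, 256) and (1, 3136, 256)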
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchaudio
+ torchvision
+ git+https://github.com/thng292/facenet-pytorch.git
+ gradio
+ fastrtc
+ streamlit
+ streamlit-webrtc
+ opencv-python
+ huggingface_hub[cli]
+ transformers[torch]
+ datasets
+ mediapipe
+ deepspeed