leonelhs committed
Commit 75b33dc · Parent(s): f2d7a5e

source bisnet is added internally
Files changed (7):
  1. .gitignore +3 -0
  2. README.md +3 -4
  3. app.py +54 -29
  4. bisnet/__init__.py +39 -0
  5. bisnet/model.py +321 -0
  6. bisnet/resnet.py +145 -0
  7. requirements.txt +7 -5
.gitignore CHANGED
@@ -1,2 +1,5 @@
  .idea/
  __pycache__/
+ .gradio
+ playground.py
+ resnet18-5c106cde.pth
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: Face Parser
  emoji: 👁
- colorFrom: blue
+ colorFrom: green
  colorTo: green
  sdk: gradio
- sdk_version: 3.34.0
+ sdk_version: 5.46.1
  app_file: app.py
  pinned: false
  license: mit
+ short_description: Extracts facial features (hair, nose, eyes, etc.)
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,23 +1,42 @@
- import os
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ #
+ #######################################################################################
+ #
+ # This project is one of several repositories exploring image segmentation techniques.
+ # All related projects and interactive demos can be found at:
+ # https://huggingface.co/spaces/leonelhs/removators
+ # Self app: https://huggingface.co/spaces/leonelhs/rembg
+ #
+ # Source code is based on or inspired by several projects.
+ # For more details and proper attribution, please refer to the following resources:
+ #
+ # - [face-makeup.PyTorch] - [https://github.com/zllrunning/face-makeup.PyTorch]
+ # - [BiSeNet] [https://github.com/CoinCheung/BiSeNet]

  import gradio as gr
  import numpy as np
  import torch
  from PIL import Image
- from bisnet import BiSeNet
- from huggingface_hub import snapshot_download
+ from huggingface_hub import hf_hub_download

+ from bisnet import BiSeNet
  from utils import vis_parsing_maps, decode_segmentation_masks, image_to_tensor

- os.system("pip freeze")
-
  REPO_ID = "leonelhs/faceparser"
  MODEL_NAME = "79999_iter.pth"

  model = BiSeNet(n_classes=19)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- snapshot_folder = snapshot_download(repo_id=REPO_ID)
- model_path = os.path.join(snapshot_folder, MODEL_NAME)
+
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME)
  model.load_state_dict(torch.load(model_path, map_location=device))
  model.eval()

@@ -47,33 +66,39 @@ def predict(image):
      return overlay


- title = "Face Parser"
- description = r"""
- ## Image face parser for research
+ aboutme = r"""
+ # PyTorch Image Face Parser

- This is an implementation of <a href='https://github.com/zllrunning/face-parsing.PyTorch' target='_blank'>face-parsing.PyTorch</a>.
- It has no any particular purpose than start research on AI models.
+ Extracts facial features (hair, nose, eyes, etc.) from images using image segmentation.
+
+ This project is part of a larger collection of repositories exploring image segmentation techniques.
+ Related projects and interactive demos are available at: [Removators](https://huggingface.co/spaces/leonelhs/removators)
+
+ ## Acknowledgments
+ The source code is based on or inspired by the following projects:
+ - [face-makeup.PyTorch](https://github.com/zllrunning/face-makeup.PyTorch)
+ - [BiSeNet](https://github.com/CoinCheung/BiSeNet)
+
+ ## Contact
+ For questions, comments, or feedback, please contact:
+ 📧 [email protected]

  """

- article = r"""
- Questions, doubts, comments, please email 📧 `[email protected]`
+ with gr.Blocks(title="Face Parser") as app:
+     navbar = gr.Navbar(visible=True, main_page_name="Workspace")
+     gr.Markdown("## Face Parser Tool")
+     with gr.Row():
+         with gr.Column(scale=1):
+             inp = gr.Image(type="pil", label="Upload Image")
+             btn_predict = gr.Button("Parse")
+         with gr.Column(scale=2):
+             out = gr.Image(type="pil", label="Output image")

- This demo is running on a CPU, if you like this project please make us a donation to run on a GPU or just give us a <a href='https://github.com/leonelhs/zeroscratches/' target='_blank'>Github ⭐</a>
+     btn_predict.click(predict, inputs=[inp], outputs=[out])

- <a href="https://www.buymeacoffee.com/leonelhs"><img src="https://img.buymeacoffee.com/button-api/?text=Buy me a coffee&emoji=&slug=leonelhs&button_colour=FFDD00&font_colour=000000&font_family=Cookie&outline_colour=000000&coffee_colour=ffffff" /></a>

- <center><img src='https://visitor-badge.glitch.me/badge?page_id=zeroscratches.visitor-badge' alt='visitor badge'></center>
- """
+ with app.route("About this", "/about"):
+     gr.Markdown(aboutme)

- demo = gr.Interface(
-     predict, [
-         gr.Image(type="pil", label="Input"),
-     ], [
-         gr.Image(type="numpy", label="Image face parsed")
-     ],
-     title=title,
-     description=description,
-     article=article)
-
- demo.queue().launch()
+ app.launch()
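
Note on the loading change above: `hf_hub_download` fetches and caches the single checkpoint file, where the previous `snapshot_download` pulled the whole model repo and then needed `os.path.join` to locate the file. A minimal sketch of the new path, with the repo and filename exactly as defined in app.py:

```python
from huggingface_hub import hf_hub_download

# Downloads the one checkpoint file on first use, then reuses the local HF cache.
model_path = hf_hub_download(repo_id="leonelhs/faceparser", filename="79999_iter.pth")
print(model_path)  # resolved path inside the local cache
```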
bisnet/__init__.py ADDED
@@ -0,0 +1,39 @@
+ # MIT License
+ #
+ # Copyright (c) [2025] [[email protected]]
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ #
+ #######################################################################################
+ #
+ # This project is one of several repositories exploring image segmentation techniques.
+ # All related projects and interactive demos can be found at:
+ # https://huggingface.co/spaces/leonelhs/removators
+ # Self app: https://huggingface.co/spaces/leonelhs/rembg
+ #
+ # Source code is based on or inspired by several projects.
+ # For more details and proper attribution, please refer to the following resources:
+ #
+ # - [face-makeup.PyTorch] - [https://github.com/zllrunning/face-makeup.PyTorch]
+ # - [BiSeNet] [https://github.com/CoinCheung/BiSeNet]
+
+ from .model import BiSeNet
+
+ __version__ = "1.0.1"
+
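
With `bisnet` now vendored as a local package, `from bisnet import BiSeNet` resolves to this directory instead of the former `bisnet~=1.0.1` PyPI pin dropped from requirements.txt below. A quick import sanity check, as a sketch that assumes it runs from the repo root:

```python
# Assumes the working directory is the repo root, so the local bisnet/ is importable.
import bisnet
from bisnet import BiSeNet

print(bisnet.__version__)    # "1.0.1"
net = BiSeNet(n_classes=19)  # 19 face-parsing classes, matching app.py
# Note: constructing BiSeNet downloads the ResNet-18 backbone weights
# via hf_hub_download (see bisnet/resnet.py below).
```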
bisnet/model.py ADDED
@@ -0,0 +1,321 @@
+ # MIT License
+ #
+ # Copyright (c) [2025] [[email protected]]
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ #
+ #######################################################################################
+ #
+ # This project is one of several repositories exploring image segmentation techniques.
+ # All related projects and interactive demos can be found at:
+ # https://huggingface.co/spaces/leonelhs/removators
+ # Self app: https://huggingface.co/spaces/leonelhs/rembg
+ #
+ # Source code is based on or inspired by several projects.
+ # For more details and proper attribution, please refer to the following resources:
+ #
+ # - [face-makeup.PyTorch] - [https://github.com/zllrunning/face-makeup.PyTorch]
+ # - [BiSeNet] [https://github.com/CoinCheung/BiSeNet]
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .resnet import Resnet18
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+
+ class ConvBNReLU(nn.Module):
+     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
+         super(ConvBNReLU, self).__init__()
+         self.conv = nn.Conv2d(in_chan,
+                               out_chan,
+                               kernel_size=ks,
+                               stride=stride,
+                               padding=padding,
+                               bias=False)
+         self.bn = nn.BatchNorm2d(out_chan)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = F.relu(self.bn(x))
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+
+ class BiSeNetOutput(nn.Module):
+     def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+         super(BiSeNetOutput, self).__init__()
+         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+         self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv(x)
+         x = self.conv_out(x)
+         return x
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class AttentionRefinementModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(AttentionRefinementModule, self).__init__()
+         self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+         self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+         self.bn_atten = nn.BatchNorm2d(out_chan)
+         self.sigmoid_atten = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv(x)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv_atten(atten)
+         atten = self.bn_atten(atten)
+         atten = self.sigmoid_atten(atten)
+         return torch.mul(feat, atten)
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+
+ class ContextPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(ContextPath, self).__init__()
+         self.resnet = Resnet18()
+         self.arm16 = AttentionRefinementModule(256, 128)
+         self.arm32 = AttentionRefinementModule(512, 128)
+         self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+         self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+         self.init_weight()
+
+     def forward(self, x):
+         H0, W0 = x.size()[2:]
+         feat8, feat16, feat32 = self.resnet(x)
+         H8, W8 = feat8.size()[2:]
+         H16, W16 = feat16.size()[2:]
+         H32, W32 = feat32.size()[2:]
+
+         avg = F.avg_pool2d(feat32, feat32.size()[2:])
+         avg = self.conv_avg(avg)
+         avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
+
+         feat32_arm = self.arm32(feat32)
+         feat32_sum = feat32_arm + avg_up
+         feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
+         feat32_up = self.conv_head32(feat32_up)
+
+         feat16_arm = self.arm16(feat16)
+         feat16_sum = feat16_arm + feat32_up
+         feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
+         feat16_up = self.conv_head16(feat16_up)
+
+         return feat8, feat16_up, feat32_up  # x8, x8, x16
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ # This is not used, since I replace this with the resnet feature with the same size
+ class SpatialPath(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(SpatialPath, self).__init__()
+         self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
+         self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
+         self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
+         self.init_weight()
+
+     def forward(self, x):
+         feat = self.conv1(x)
+         feat = self.conv2(feat)
+         feat = self.conv3(feat)
+         feat = self.conv_out(feat)
+         return feat
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class FeatureFusionModule(nn.Module):
+     def __init__(self, in_chan, out_chan, *args, **kwargs):
+         super(FeatureFusionModule, self).__init__()
+         self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+         self.conv1 = nn.Conv2d(out_chan,
+                                out_chan // 4,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.conv2 = nn.Conv2d(out_chan // 4,
+                                out_chan,
+                                kernel_size=1,
+                                stride=1,
+                                padding=0,
+                                bias=False)
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+         self.init_weight()
+
+     def forward(self, fsp, fcp):
+         fcat = torch.cat([fsp, fcp], dim=1)
+         feat = self.convblk(fcat)
+         atten = F.avg_pool2d(feat, feat.size()[2:])
+         atten = self.conv1(atten)
+         atten = self.relu(atten)
+         atten = self.conv2(atten)
+         atten = self.sigmoid(atten)
+         feat_atten = torch.mul(feat, atten)
+         feat_out = feat_atten + feat
+         return feat_out
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ class BiSeNet(nn.Module):
+     def __init__(self, n_classes, *args, **kwargs):
+         super(BiSeNet, self).__init__()
+         self.cp = ContextPath()
+         # here self.sp is deleted
+         self.ffm = FeatureFusionModule(256, 256)
+         self.conv_out = BiSeNetOutput(256, 256, n_classes)
+         self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
+         self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+         self.init_weight()
+
+     def forward(self, x):
+         H, W = x.size()[2:]
+         feat_res8, feat_cp8, feat_cp16 = self.cp(x)  # here return res3b1 feature
+         feat_sp = feat_res8  # use res3b1 feature to replace spatial path feature
+         feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+         feat_out = self.conv_out(feat_fuse)
+         feat_out16 = self.conv_out16(feat_cp8)
+         feat_out32 = self.conv_out32(feat_cp16)
+
+         feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
+         feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
+         feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
+         return feat_out, feat_out16, feat_out32
+
+     def init_weight(self):
+         for ly in self.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+                 if ly.bias is not None:
+                     nn.init.constant_(ly.bias, 0)
+
+     def get_params(self):
+         wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
+         for name, child in self.named_children():
+             child_wd_params, child_nowd_params = child.get_params()
+             if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
+                 lr_mul_wd_params += child_wd_params
+                 lr_mul_nowd_params += child_nowd_params
+             else:
+                 wd_params += child_wd_params
+                 nowd_params += child_nowd_params
+         return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
+
+
+ if __name__ == "__main__":
+     net = BiSeNet(19)
+     net.cuda()
+     net.eval()
+     in_ten = torch.randn(16, 3, 640, 480).cuda()
+     out, out16, out32 = net(in_ten)
+     print(out.shape)
+
+     net.get_params()
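
At inference only the first of the three returned heads matters; `feat_out16` and `feat_out32` are auxiliary heads kept for training-time supervision. A CPU sketch of turning the logits into a per-pixel label map (the 1x3x512x512 input is an arbitrary stand-in for a preprocessed face image):

```python
import torch

from bisnet import BiSeNet

net = BiSeNet(n_classes=19)  # also pulls the ResNet-18 backbone checkpoint
net.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)  # stand-in for a normalized face image
    out, out16, out32 = net(x)       # each head: (1, 19, 512, 512) logits
    parsing = out.argmax(dim=1)      # (1, 512, 512) class indices in 0..18

print(parsing.shape)
```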
bisnet/resnet.py ADDED
@@ -0,0 +1,145 @@
+ #######################################################################################
+ #
+ # MIT License
+ #
+ # Copyright (c) [2025] [[email protected]]
+ #
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+ #
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+ #
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE.
+ #
+ #######################################################################################
+ #
+ # This project is one of several repositories exploring image segmentation techniques.
+ # All related projects and interactive demos can be found at:
+ # https://huggingface.co/spaces/leonelhs/removators
+ # Self app: https://huggingface.co/spaces/leonelhs/rembg
+ #
+ # Source code is based on or inspired by several projects.
+ # For more details and proper attribution, please refer to the following resources:
+ #
+ # - [face-makeup.PyTorch] - [https://github.com/zllrunning/face-makeup.PyTorch]
+ # - [BiSeNet] [https://github.com/CoinCheung/BiSeNet]
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import hf_hub_download
+
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
+
+ REPO_ID = "leonelhs/faceparser"
+ CKPT = "resnet18-5c106cde.pth"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_chan, out_chan, stride=1):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(in_chan, out_chan, stride)
+         self.bn1 = nn.BatchNorm2d(out_chan)
+         self.conv2 = conv3x3(out_chan, out_chan)
+         self.bn2 = nn.BatchNorm2d(out_chan)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = None
+         if in_chan != out_chan or stride != 1:
+             self.downsample = nn.Sequential(
+                 nn.Conv2d(in_chan, out_chan,
+                           kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(out_chan),
+             )
+
+     def forward(self, x):
+         residual = self.conv1(x)
+         residual = F.relu(self.bn1(residual))
+         residual = self.conv2(residual)
+         residual = self.bn2(residual)
+
+         shortcut = x
+         if self.downsample is not None:
+             shortcut = self.downsample(x)
+
+         out = shortcut + residual
+         out = self.relu(out)
+         return out
+
+
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+     layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+     for i in range(bnum - 1):
+         layers.append(BasicBlock(out_chan, out_chan, stride=1))
+     return nn.Sequential(*layers)
+
+
+ class Resnet18(nn.Module):
+     def __init__(self):
+         super(Resnet18, self).__init__()
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+         self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+         self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+         self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+         self.init_weight()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = F.relu(self.bn1(x))
+         x = self.maxpool(x)
+
+         x = self.layer1(x)
+         feat8 = self.layer2(x)  # 1/8
+         feat16 = self.layer3(feat8)  # 1/16
+         feat32 = self.layer4(feat16)  # 1/32
+         return feat8, feat16, feat32
+
+     def init_weight(self):
+         checkpoint = hf_hub_download(repo_id=REPO_ID, filename=CKPT)
+         state_dict = torch.load(checkpoint, map_location=device, weights_only=False)
+         self_state_dict = self.state_dict()
+         for k, v in state_dict.items():
+             if 'fc' in k:
+                 continue
+             self_state_dict.update({k: v})
+         self.load_state_dict(self_state_dict)
+
+     def get_params(self):
+         wd_params, nowd_params = [], []
+         for name, module in self.named_modules():
+             if isinstance(module, (nn.Linear, nn.Conv2d)):
+                 wd_params.append(module.weight)
+                 if module.bias is not None:
+                     nowd_params.append(module.bias)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nowd_params += list(module.parameters())
+         return wd_params, nowd_params
+
+
+ if __name__ == "__main__":
+     net = Resnet18()
+     x = torch.randn(16, 3, 224, 224)
+     out = net(x)
+     print(out[0].size())
+     print(out[1].size())
+     print(out[2].size())
+     net.get_params()
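
`init_weight()` above seeds the backbone with the torchvision ImageNet ResNet-18 checkpoint while skipping the `fc.*` classifier entries, which have no counterpart in this headless trunk. A sketch of what gets filtered, using the same repo and filename constants (the expected key names assume the standard torchvision checkpoint layout):

```python
import torch
from huggingface_hub import hf_hub_download

ckpt = hf_hub_download(repo_id="leonelhs/faceparser", filename="resnet18-5c106cde.pth")
state_dict = torch.load(ckpt, map_location="cpu", weights_only=False)

skipped = [k for k in state_dict if "fc" in k]
print(skipped)  # expected: ['fc.weight', 'fc.bias'], the ImageNet classifier head
```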
requirements.txt CHANGED
@@ -1,5 +1,7 @@
- torch>=2.0.1
- torchvision~=0.15.2
- pillow~=9.5.0
- bisnet~=1.0.1
- opencv-python
+ torch>=2.8.0
+ torchvision>=0.23.0
+ opencv-python-headless>=4.12.0.88
+ gradio~=5.46.1
+ numpy~=2.1.2
+ pillow~=11.0.0
+ huggingface-hub~=0.35.0