init

Changed files:
- APDrawingGAN2/data/aligned_dataset.py  +109 -93
- APDrawingGAN2/data/single_dataset.py  +13 -12
- APDrawingGAN2/models/apdrawingpp_style_model.py  +216 -181
- APDrawingGAN2/models/base_model.py  +149 -109
APDrawingGAN2/data/aligned_dataset.py  CHANGED

@@ -9,50 +9,55 @@ import numpy as np
 import cv2
 import csv
 
+
 def getfeats(featpath):
-    trans_points = np.empty([5,2],dtype=np.int64)
-    with open(featpath, 'r') as csvfile:
-        reader = csv.reader(csvfile, delimiter=' ')
-        for ind,row in enumerate(reader):
-            trans_points[ind,:] = row
-    return trans_points
+    trans_points = np.empty([5, 2], dtype=np.int64)
+    with open(featpath, 'r') as csvfile:
+        reader = csv.reader(csvfile, delimiter=' ')
+        for ind, row in enumerate(reader):
+            trans_points[ind, :] = row
+    return trans_points
+
 
 def tocv2(ts):
-    img = (ts.numpy()/2+0.5)*255
+    img = (ts.numpy() / 2 + 0.5) * 255
     img = img.astype('uint8')
-    img = np.transpose(img,(1,2,0))
-    img = img[:,:,::-1]#rgb->bgr
+    img = np.transpose(img, (1, 2, 0))
+    img = img[:, :, ::-1]  # rgb->bgr
     return img
 
+
 def dt(img):
-    if(img.shape[2]==3):
-        img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
-    #convert to BW
-    ret1,thresh1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY)
-    ret2,thresh2 = cv2.threshold(img,127,255,cv2.THRESH_BINARY_INV)
-    dt1 = cv2.distanceTransform(thresh1,cv2.DIST_L2,5)
-    dt2 = cv2.distanceTransform(thresh2,cv2.DIST_L2,5)
-    dt1 = dt1/dt1.max()
-    dt2 = dt2/dt2.max()
+    if (img.shape[2] == 3):
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    # convert to BW
+    ret1, thresh1 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY)
+    ret2, thresh2 = cv2.threshold(img, 127, 255, cv2.THRESH_BINARY_INV)
+    dt1 = cv2.distanceTransform(thresh1, cv2.DIST_L2, 5)
+    dt2 = cv2.distanceTransform(thresh2, cv2.DIST_L2, 5)
+    dt1 = dt1 / dt1.max()  # ->[0,1]
+    dt2 = dt2 / dt2.max()
     return dt1, dt2
 
-def getSoft(size,xb,yb,boundwidth=5.0):
-    xarray = np.tile(np.arange(0,size[1]),(size[0],1))
-    yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose()
+
+def getSoft(size, xb, yb, boundwidth=5.0):
+    xarray = np.tile(np.arange(0, size[1]), (size[0], 1))
+    yarray = np.tile(np.arange(0, size[0]), (size[1], 1)).transpose()
     cxdists = []
     cydists = []
     for i in range(len(xb)):
-        xba = np.tile(xb[i],(size[1],1)).transpose()
-        yba = np.tile(yb[i],(size[0],1))
-        cxdists.append(np.abs(xarray-xba))
-        cydists.append(np.abs(yarray-yba))
+        xba = np.tile(xb[i], (size[1], 1)).transpose()
+        yba = np.tile(yb[i], (size[0], 1))
+        cxdists.append(np.abs(xarray - xba))
+        cydists.append(np.abs(yarray - yba))
     xdist = np.minimum.reduce(cxdists)
     ydist = np.minimum.reduce(cydists)
-    manhdist = np.minimum.reduce([xdist,ydist])
-    im = (manhdist+1) / (boundwidth+1) * 1.0
-    im[im>=1.0] = 1.0
+    manhdist = np.minimum.reduce([xdist, ydist])
+    im = (manhdist + 1) / (boundwidth + 1) * 1.0
+    im[im >= 1.0] = 1.0
     return im
 
+
 class AlignedDataset(BaseDataset):
     @staticmethod
     def modify_commandline_options(parser, is_train):

@@ -71,17 +76,17 @@ class AlignedDataset(BaseDataset):
         else:
             self.dir_AB = os.path.join(opt.dataroot, opt.phase)
             self.AB_paths = sorted(make_dataset(self.dir_AB))
-        assert(opt.resize_or_crop == 'resize_and_crop')
+        assert (opt.resize_or_crop == 'resize_and_crop')
 
     def __getitem__(self, index):
         AB_path = self.AB_paths[index]
         AB = Image.open(AB_path).convert('RGB')
         w, h = AB.size
-        if w/h == 2:
+        if w / h == 2:
             w2 = int(w / 2)
             A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
             B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
-        else:
+        else:  # if w/h != 2, need B_paths
            A = AB.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
            B = Image.open(self.B_paths[index]).convert('RGB')
            B = B.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)

@@ -90,7 +95,7 @@ class AlignedDataset(BaseDataset):
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
 
-        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]#C,H,W
+        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]  # C,H,W
        B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
 
        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)

@@ -118,25 +123,25 @@ class AlignedDataset(BaseDataset):
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
 
        item = {'A': A, 'B': B,
                'A_paths': AB_path, 'B_paths': AB_path}
 
        if self.opt.use_local:
-            regions = ['eyel','eyer','nose','mouth']
-            basen = os.path.basename(AB_path)[:-4]+'.txt'
-            if self.opt.region_enm in [0,1]:
+            regions = ['eyel', 'eyer', 'nose', 'mouth']
+            basen = os.path.basename(AB_path)[:-4] + '.txt'
+            if self.opt.region_enm in [0, 1]:
                featdir = self.opt.lm_dir
-                featpath = os.path.join(featdir,basen)
+                featpath = os.path.join(featdir, basen)
                feats = getfeats(featpath)
                if flipped:
                    for i in range(5):
-                        feats[i,0] = self.opt.fineSize - feats[i,0] - 1
-                    tmp = [feats[0,0],feats[0,1]]
-                    feats[0
-                    feats[1
-                mouth_x = int((feats[3,0]+feats[4,0])/2.0)
-                mouth_y = int((feats[3,1]+feats[4,1])/2.0)
+                        feats[i, 0] = self.opt.fineSize - feats[i, 0] - 1
+                    tmp = [feats[0, 0], feats[0, 1]]
+                    feats[0, :] = [feats[1, 0], feats[1, 1]]
+                    feats[1, :] = tmp
+                mouth_x = int((feats[3, 0] + feats[4, 0]) / 2.0)
+                mouth_y = int((feats[3, 1] + feats[4, 1]) / 2.0)
                ratio = self.opt.fineSize / 256
                EYE_H = self.opt.EYE_H * ratio
                EYE_W = self.opt.EYE_W * ratio

@@ -144,32 +149,37 @@ class AlignedDataset(BaseDataset):
                NOSE_W = self.opt.NOSE_W * ratio
                MOUTH_H = self.opt.MOUTH_H * ratio
                MOUTH_W = self.opt.MOUTH_W * ratio
-                center = torch.
+                center = torch.LongTensor(
+                    [[feats[0, 0], feats[0, 1] - 4 * ratio], [feats[1, 0], feats[1, 1] - 4 * ratio],
+                     [feats[2, 0], feats[2, 1] - NOSE_H / 2 + 16 * ratio], [mouth_x, mouth_y]])
                item['center'] = center
-                rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
-                rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]
+                rhs = [int(EYE_H), int(EYE_H), int(NOSE_H), int(MOUTH_H)]
+                rws = [int(EYE_W), int(EYE_W), int(NOSE_W), int(MOUTH_W)]
                if self.opt.soft_border:
                    soft_border_mask4 = []
                    for i in range(4):
-                        xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)]
-                        yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)]
-                        soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb)
+                        xb = [np.zeros(rhs[i]), np.ones(rhs[i]) * (rws[i] - 1)]
+                        yb = [np.zeros(rws[i]), np.ones(rws[i]) * (rhs[i] - 1)]
+                        soft_border_mask = getSoft([rhs[i], rws[i]], xb, yb)
                        soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
-                        item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
+                        item['soft_' + regions[i] + '_mask'] = soft_border_mask4[i]
                for i in range(4):
-                    item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,
+                    item[regions[i] + '_A'] = A[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2),
+                                              int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)]
+                    item[regions[i] + '_B'] = B[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2),
+                                              int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)]
                    if self.opt.soft_border:
-                        item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(
+                        item[regions[i] + '_A'] = item[regions[i] + '_A'] * soft_border_mask4[i].repeat(
+                            int(input_nc / output_nc), 1, 1)
+                        item[regions[i] + '_B'] = item[regions[i] + '_B'] * soft_border_mask4[i]
            if self.opt.compactmask:
                cmasks0 = []
                cmasks = []
                for i in range(4):
-                    if flipped and i in [0,1]:
-                        cmaskpath = os.path.join(self.opt.cmask_dir,regions[1-i],basen[:-4]+'.png')
+                    if flipped and i in [0, 1]:
+                        cmaskpath = os.path.join(self.opt.cmask_dir, regions[1 - i], basen[:-4] + '.png')
                    else:
-                        cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png')
+                        cmaskpath = os.path.join(self.opt.cmask_dir, regions[i], basen[:-4] + '.png')
                    im_cmask = Image.open(cmaskpath)
                    cmask0 = transforms.ToTensor()(im_cmask)
                    if flipped:

@@ -180,11 +190,12 @@ class AlignedDataset(BaseDataset):
                    cmask0 = (cmask0 >= 0.5).float()
                    cmasks0.append(cmask0)
                    cmask = cmask0.clone()
-                    if self.opt.region_enm in [0,1]:
-                        cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,
-                        item[regions[i]+'
+                    if self.opt.region_enm in [0, 1]:
+                        cmask = cmask[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2),
+                                int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)]
+                    elif self.opt.region_enm in [2]:  # need to multiply cmask
+                        item[regions[i] + '_A'] = (A / 2 + 0.5) * cmask * 2 - 1
+                        item[regions[i] + '_B'] = (B / 2 + 0.5) * cmask * 2 - 1
                    cmasks.append(cmask)
                item['cmaskel'] = cmasks[0]
                item['cmasker'] = cmasks[1]

@@ -194,70 +205,75 @@ class AlignedDataset(BaseDataset):
            mask = torch.ones(B.shape)
            if self.opt.region_enm == 0:
                for i in range(4):
-                    mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,
+                    mask[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2),
+                    int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)] = 0
                if self.opt.soft_border:
                    imgsize = self.opt.fineSize
                    maskn = mask[0].numpy()
-                    masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),
+                    masks = [np.ones([imgsize, imgsize]), np.ones([imgsize, imgsize]), np.ones([imgsize, imgsize]),
+                             np.ones([imgsize, imgsize])]
                    masks[0][1:] = maskn[:-1]
                    masks[1][:-1] = maskn[1:]
-                    masks[2][:,1:] = maskn[
-                    masks[3][
-                    masks2 = [maskn-e for e in masks]
+                    masks[2][:, 1:] = maskn[:, :-1]
+                    masks[3][:, :-1] = maskn[:, 1:]
+                    masks2 = [maskn - e for e in masks]
                    bound = np.minimum.reduce(masks2)
                    bound = -bound
                    xb = []
                    yb = []
                    for i in range(4):
-                        xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
-                        ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
+                        xbi = [int(center[i, 0] - rws[i] / 2), int(center[i, 0] + rws[i] / 2 - 1)]
+                        ybi = [int(center[i, 1] - rhs[i] / 2), int(center[i, 1] + rhs[i] / 2 - 1)]
                        for j in range(2):
-                            maskx = bound[:,xbi[j]]
-                            masky = bound[ybi[j]
-                            tmp_a = torch.from_numpy(maskx)*xbi[j]
-                            tmp_b = torch.from_numpy(1-maskx)
-                            xb += [tmp_b*10000 + tmp_a]
+                            maskx = bound[:, xbi[j]]
+                            masky = bound[ybi[j], :]
+                            tmp_a = torch.from_numpy(maskx) * xbi[j]
+                            tmp_b = torch.from_numpy(1 - maskx)
+                            xb += [tmp_b * 10000 + tmp_a]
 
-                            tmp_a = torch.from_numpy(masky)*ybi[j]
-                            tmp_b = torch.from_numpy(1-masky)
-                            yb += [tmp_b*10000 + tmp_a]
-                    soft = 1-getSoft([imgsize,imgsize],xb,yb)
+                            tmp_a = torch.from_numpy(masky) * ybi[j]
+                            tmp_b = torch.from_numpy(1 - masky)
+                            yb += [tmp_b * 10000 + tmp_a]
+                    soft = 1 - getSoft([imgsize, imgsize], xb, yb)
                    soft = torch.Tensor(soft).unsqueeze(0)
-                    mask = (torch.ones(mask.shape)-mask)*soft + mask
+                    mask = (torch.ones(mask.shape) - mask) * soft + mask
            elif self.opt.region_enm == 1:
                for i in range(4):
                    cmask0 = cmasks0[i]
                    rec = torch.zeros(B.shape)
-                    rec[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,
+                    rec[:, int(center[i, 1] - rhs[i] / 2):int(center[i, 1] + rhs[i] / 2),
+                    int(center[i, 0] - rws[i] / 2):int(center[i, 0] + rws[i] / 2)] = 1
                    mask = mask * (torch.ones(B.shape) - cmask0 * rec)
            elif self.opt.region_enm == 2:
                for i in range(4):
                    cmask0 = cmasks0[i]
                    mask = mask * (torch.ones(B.shape) - cmask0)
-            hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1
-            hair_B = (B/2+0.5) * mask * 2 - 1
+            hair_A = (A / 2 + 0.5) * mask.repeat(int(input_nc / output_nc), 1, 1) * 2 - 1
+            hair_B = (B / 2 + 0.5) * mask * 2 - 1
            item['hair_A'] = hair_A
            item['hair_B'] = hair_B
-            item['mask'] = mask
+            item['mask'] = mask  # mask out eyes, nose, mouth
            if self.opt.bg_local:
                bgdir = self.opt.bg_dir
-                bgpath = os.path.join(bgdir,basen[:-4]+'.png')
+                bgpath = os.path.join(bgdir, basen[:-4] + '.png')
                im_bg = Image.open(bgpath)
-                mask2 = transforms.ToTensor()(im_bg)
+                mask2 = transforms.ToTensor()(im_bg)  # mask out background
                if flipped:
                    mask2 = mask2.index_select(2, idx)
                mask2 = (mask2 >= 0.5).float()
-                hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(
+                hair_A = (A / 2 + 0.5) * mask.repeat(int(input_nc / output_nc), 1, 1) * mask2.repeat(
+                    int(input_nc / output_nc), 1, 1) * 2 - 1
+                hair_B = (B / 2 + 0.5) * mask * mask2 * 2 - 1
+                bg_A = (A / 2 + 0.5) * (torch.ones(mask2.shape) - mask2).repeat(int(input_nc / output_nc), 1,
+                                                                                1) * 2 - 1
+                bg_B = (B / 2 + 0.5) * (torch.ones(mask2.shape) - mask2) * 2 - 1
                item['hair_A'] = hair_A
                item['hair_B'] = hair_B
                item['bg_A'] = bg_A
                item['bg_B'] = bg_B
                item['mask'] = mask
                item['mask2'] = mask2
 
        if (self.opt.isTrain and self.opt.chamfer_loss):
            if self.opt.which_direction == 'AtoB':
                img = tocv2(B)

@@ -270,11 +286,11 @@ class AlignedDataset(BaseDataset):
                dt2 = dt2.unsqueeze(0)
                item['dt1gt'] = dt1
                item['dt2gt'] = dt2
 
        if self.opt.isTrain and self.opt.emphasis_conti_face:
-            face_mask_path = os.path.join(self.opt.facemask_dir,basen[:-4]+'.png')
+            face_mask_path = os.path.join(self.opt.facemask_dir, basen[:-4] + '.png')
            face_mask = Image.open(face_mask_path)
-            face_mask = transforms.ToTensor()(face_mask)
+            face_mask = transforms.ToTensor()(face_mask)  # [0,1]
            if flipped:
                face_mask = face_mask.index_select(2, idx)
            item['face_mask'] = face_mask
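
Note on the helpers at the top of this file: dt returns two distance transforms of a binarized drawing (distance to the black strokes and to the white background), each normalized to [0,1], and getSoft builds a mask that is small on a rectangle's border and ramps to 1.0 over boundwidth pixels inside it. A minimal standalone sketch of how __getitem__ exercises them; the window size is illustrative and the import assumes the APDrawingGAN2 root is on the path:

    import numpy as np
    import cv2
    from data.aligned_dataset import getSoft, dt

    # Soft border for one local window, as in the soft_border branch above.
    rh, rw = 40, 56                              # an illustrative EYE_H x EYE_W window
    xb = [np.zeros(rh), np.ones(rh) * (rw - 1)]  # x coordinates of the left/right edges
    yb = [np.zeros(rw), np.ones(rw) * (rh - 1)]  # y coordinates of the top/bottom edges
    soft = getSoft([rh, rw], xb, yb)             # ~0.17 on the border, 1.0 ~5 px inside
    assert soft.shape == (rh, rw)

    # Distance transforms of a synthetic BGR drawing, as produced by tocv2(B).
    img = np.full((512, 512, 3), 255, np.uint8)
    cv2.line(img, (100, 100), (400, 400), (0, 0, 0), 3)
    dt1, dt2 = dt(img)                           # two (512, 512) float maps in [0,1]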
APDrawingGAN2/data/single_dataset.py  CHANGED

@@ -82,7 +82,7 @@ class SingleDataset(BaseDataset):
        NOSE_W = self.opt.NOSE_W * ratio
        MOUTH_H = self.opt.MOUTH_H * ratio
        MOUTH_W = self.opt.MOUTH_W * ratio
-        center = torch.
+        center = torch.LongTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]])
        item['center'] = center
        rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)]
        rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)]

@@ -95,7 +95,10 @@ class SingleDataset(BaseDataset):
                soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0))
                item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i]
        for i in range(4):
-            item[regions[i]+'_A'] = A[:,
+            item[regions[i]+'_A'] = A[:,(center[i,1]-rhs[i]/2).to(torch.long):
+                                      (center[i,1]+rhs[i]/2).to(torch.long),
+                                      (center[i,0]-rws[i]/2).to(torch.long):
+                                      (center[i,0]+rws[i]/2).to(torch.long)]
            if self.opt.soft_border:
                item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1)
        if self.opt.compactmask:

@@ -111,7 +114,7 @@ class SingleDataset(BaseDataset):
            cmask0 = (cmask0 >= 0.5).float()
            cmasks0.append(cmask0)
            cmask = cmask0.clone()
-            cmask = cmask[:,
+            cmask = cmask[:,(center[i,1]-rhs[i]/2).to(torch.long):(center[i,1]+rhs[i]/2).to(torch.long),(center[i,0]-rws[i]/2).to(torch.long):(center[i,0]+rws[i]/2).to(torch.long)]
            cmasks.append(cmask)
        item['cmaskel'] = cmasks[0]
        item['cmasker'] = cmasks[1]

@@ -121,7 +124,7 @@ class SingleDataset(BaseDataset):
        output_nc = self.opt.output_nc
        mask = torch.ones([output_nc,A.shape[1],A.shape[2]])
        for i in range(4):
-            mask[:,
+            mask[:,(center[i,1]-rhs[i]/2).to(torch.long):(center[i,1]+rhs[i]/2).to(torch.long),(center[i,0]-rws[i]/2).to(torch.long):(center[i,0]+rws[i]/2).to(torch.long)] = 0
        if self.opt.soft_border:
            imgsize = self.opt.fineSize
            maskn = mask[0].numpy()

@@ -136,11 +139,11 @@ class SingleDataset(BaseDataset):
            xb = []
            yb = []
            for i in range(4):
-                xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1]
-                ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1]
+                xbi = [(center[i,0]-rws[i]/2).to(torch.long), (center[i,0]+rws[i]/2-1).to(torch.long)]
+                ybi = [(center[i,1]-rhs[i]/2).to(torch.long), (center[i,1]+rhs[i]/2-1).to(torch.long)]
                for j in range(2):
-                    maskx = bound[:,
-                    masky = bound[
+                    maskx = bound[:,xbi[j]]
+                    masky = bound[ybi[j],:]
                    tmp_a = torch.from_numpy(maskx)*xbi[j].double()
                    tmp_b = torch.from_numpy(1-maskx)
                    xb += [tmp_b*10000 + tmp_a]

@@ -160,10 +163,8 @@ class SingleDataset(BaseDataset):
            im_bg = Image.open(bgpath)
            mask2 = transforms.ToTensor()(im_bg) # mask out background
            mask2 = (mask2 >= 0.5).float()
-            #bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
-            bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(3,1,1) * 2 - 1
+            hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1
+            bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1
            item['hair_A'] = hair_A
            item['bg_A'] = bg_A
            item['mask'] = mask
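
The recurring .to(torch.long) casts are the substantive change in this file: center holds integer coordinates, but dividing by 2 is true division in Python 3 and yields float tensors, which cannot serve as slice bounds. A minimal reproduction with made-up values:

    import torch

    center = torch.LongTensor([[68, 111]])
    rh = 80
    bad = center[0, 1] - rh / 2                    # tensor(71.), float: slicing with it raises
    lo = (center[0, 1] - rh / 2).to(torch.long)
    hi = (center[0, 1] + rh / 2).to(torch.long)
    A = torch.randn(3, 512, 512)
    crop = A[:, lo:hi, :]                          # 0-dim long tensors are valid bounds
    print(crop.shape)                              # torch.Size([3, 80, 512])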
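One more subtlety in the soft-border loop: bound is a float64 NumPy array, so torch.from_numpy(maskx) is a double tensor, and the long slice bound xbi[j] is explicitly cast with .double() so the product stays in double regardless of type-promotion rules. The *10000 trick then pushes rows that do not lie on a region boundary to a far-away coordinate, so getSoft effectively ignores them. A sketch with made-up sizes:

    import numpy as np
    import torch

    bound = np.zeros((512, 512))                 # float64 boundary map
    bound[200:280, 116] = 1.0                    # pretend rows 200..279 are on a border
    xbi = [torch.tensor(116), torch.tensor(171)] # long slice bounds, as in xbi above
    maskx = bound[:, int(xbi[0])]                # the column at the window's left edge
    tmp_a = torch.from_numpy(maskx) * xbi[0].double()  # double * double, no dtype clash
    tmp_b = torch.from_numpy(1 - maskx)
    xb_entry = tmp_b * 10000 + tmp_a             # boundary rows keep x=116, others go to 10000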
APDrawingGAN2/models/apdrawingpp_style_model.py  CHANGED

@@ -6,13 +6,13 @@ import os
 import math
 
 W = 11
-aa = int(math.floor(512./W))
-res = 512 - W*aa
+aa = int(math.floor(512. / W))
+res = 512 - W * aa
 
 
-def padpart(A,part,centers,opt,device):
+def padpart(A, part, centers, opt, device):
     IMAGE_SIZE = opt.fineSize
-    bs,nc,_,_ = A.shape
+    bs, nc, _, _ = A.shape
     ratio = IMAGE_SIZE / 256
     NOSE_W = opt.NOSE_W * ratio
     NOSE_H = opt.NOSE_H * ratio

@@ -20,37 +20,52 @@ def padpart(A, part, centers, opt, device):
     EYE_H = opt.EYE_H * ratio
     MOUTH_W = opt.MOUTH_W * ratio
     MOUTH_H = opt.MOUTH_H * ratio
-    A_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(device)
-    padvalue = -1
+    A_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(device)
+    padvalue = -1  # black
     for i in range(bs):
         center = centers[i]
         if part == 'nose':
-            A_p[i] = torch.nn.ConstantPad2d((
+            A_p[i] = torch.nn.ConstantPad2d((
+                int(center[2, 0] - NOSE_W / 2), IMAGE_SIZE - int(center[2, 0] + NOSE_W / 2),
+                int(center[2, 1] - NOSE_H / 2),
+                IMAGE_SIZE - int(center[2, 1] + NOSE_H / 2)), padvalue)(A[i])
         elif part == 'eyel':
-            A_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,
+            A_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2), IMAGE_SIZE - int(center[0, 0] + EYE_W / 2),
+                                             int(center[0, 1] - EYE_H / 2), IMAGE_SIZE - int(center[0, 1] + EYE_H / 2)),
+                                            padvalue)(A[i])
         elif part == 'eyer':
-            A_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W
+            A_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2), IMAGE_SIZE - int(center[1, 0] + EYE_W / 2),
+                                             int(center[1, 1] - EYE_H / 2), IMAGE_SIZE - int(center[1, 1] + EYE_H / 2)),
+                                            padvalue)(A[i])
         elif part == 'mouth':
-            A_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2
+            A_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2),
+                                             IMAGE_SIZE - int(center[3, 0] + MOUTH_W / 2),
+                                             int(center[3, 1] - MOUTH_H / 2),
+                                             IMAGE_SIZE - int(center[3, 1] + MOUTH_H / 2)), padvalue)(A[i])
     return A_p
 
 import numpy as np
-def nonlinearDt(dt, type='atan', xmax=torch.Tensor([10.0])):#dt in [0,1], first multiply xmax(>1), then remap to [0,1]
+
+
+def nonlinearDt(dt, type='atan',
+                xmax=torch.Tensor([10.0])):  # dt in [0,1], first multiply xmax(>1), then remap to [0,1]
     if type == 'atan':
-        nldt = torch.atan(dt*xmax) / torch.atan(xmax)
+        nldt = torch.atan(dt * xmax) / torch.atan(xmax)
     elif type == 'sigmoid':
-        nldt = (torch.sigmoid(dt*xmax)-0.5) / (torch.sigmoid(xmax)-0.5)
+        nldt = (torch.sigmoid(dt * xmax) - 0.5) / (torch.sigmoid(xmax) - 0.5)
     elif type == 'tanh':
-        nldt = torch.tanh(dt*xmax) / torch.tanh(xmax)
+        nldt = torch.tanh(dt * xmax) / torch.tanh(xmax)
     elif type == 'pow':
-        nldt = torch.pow(dt*xmax,2) / torch.pow(xmax,2)
+        nldt = torch.pow(dt * xmax, 2) / torch.pow(xmax, 2)
     elif type == 'exp':
-        if xmax.item()>1:
+        if xmax.item() > 1:
             xmax = xmax / 3
-        nldt = (torch.exp(dt*xmax)-1) / (torch.exp(xmax)-1)
-    #print("remap dt:", type, xmax.item())
+        nldt = (torch.exp(dt * xmax) - 1) / (torch.exp(xmax) - 1)
+    # print("remap dt:", type, xmax.item())
     return nldt
 
+
 class APDrawingPPStyleModel(BaseModel):
     def name(self):
         return 'APDrawingPPStyleModel'
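
nonlinearDt remaps a distance map in [0,1] so that small distances are stretched and large ones saturate, which sharpens the chamfer penalties computed later in backward_G; every branch divides by the value at dt=1 so the output stays in [0,1]. A quick check of that property using the atan branch (the body is copied from the helper above):

    import torch

    def nonlinearDt(dt, type='atan', xmax=torch.Tensor([10.0])):
        # atan branch of the helper above, for illustration
        return torch.atan(dt * xmax) / torch.atan(xmax)

    dt = torch.linspace(0, 1, 5)   # [0.00, 0.25, 0.50, 0.75, 1.00]
    print(nonlinearDt(dt))         # ~[0.00, 0.81, 0.93, 0.98, 1.00]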

@@ -60,7 +75,7 @@ class APDrawingPPStyleModel(BaseModel):
 
         # changing the default values to match the pix2pix paper
         # (https://phillipi.github.io/pix2pix/)
-        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False
+        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')  # no_lsgan=True, use_lsgan=False
         parser.set_defaults(dataset_mode='aligned')
         parser.set_defaults(auxiliary_root='auxiliaryeye2o')
         parser.set_defaults(use_local=True, hair_local=True, bg_local=True)

@@ -107,15 +122,15 @@ class APDrawingPPStyleModel(BaseModel):
         self.visual_names += ['fake_B0', 'fake_B1']
         self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
         self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
-        if self.opt.region_enm in [0,1]:
+        if self.opt.region_enm in [0, 1]:
             if self.opt.nose_ae:
-                self.visual_names += ['fake_B_nose_v','fake_B_nose_v1','fake_B_nose_v2','cmask1no']
+                self.visual_names += ['fake_B_nose_v', 'fake_B_nose_v1', 'fake_B_nose_v2', 'cmask1no']
             if self.opt.others_ae:
-                self.visual_names += ['fake_B_eyel_v','fake_B_eyel_v1','fake_B_eyel_v2','cmask1el']
-                self.visual_names += ['fake_B_eyer_v','fake_B_eyer_v1','fake_B_eyer_v2','cmask1er']
-                self.visual_names += ['fake_B_mouth_v','fake_B_mouth_v1','fake_B_mouth_v2','cmask1mo']
+                self.visual_names += ['fake_B_eyel_v', 'fake_B_eyel_v1', 'fake_B_eyel_v2', 'cmask1el']
+                self.visual_names += ['fake_B_eyer_v', 'fake_B_eyer_v1', 'fake_B_eyer_v2', 'cmask1er']
+                self.visual_names += ['fake_B_mouth_v', 'fake_B_mouth_v1', 'fake_B_mouth_v2', 'cmask1mo']
         elif self.opt.region_enm in [2]:
-            self.visual_names += ['fake_B_nose','fake_B_eyel','fake_B_eyer','fake_B_mouth']
+            self.visual_names += ['fake_B_nose', 'fake_B_eyel', 'fake_B_eyer', 'fake_B_mouth']
         if self.isTrain and self.opt.chamfer_loss:
             self.visual_names += ['dt1', 'dt2']
             self.visual_names += ['dt1gt', 'dt2gt']

@@ -129,7 +144,7 @@ class APDrawingPPStyleModel(BaseModel):
         if self.isTrain:
             self.model_names = ['G', 'D']
             if self.opt.discriminator_local:
-                self.model_names += ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
+                self.model_names += ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
             # auxiliary nets for loss calculation
             if self.opt.chamfer_loss:
                 self.auxiliary_model_names += ['DT1', 'DT2']

@@ -141,13 +156,13 @@ class APDrawingPPStyleModel(BaseModel):
             if self.opt.test_continuity_loss:
                 self.auxiliary_model_names += ['Regressor']
         if self.opt.use_local:
-            self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine']
-            self.auxiliary_model_names += ['CLm','CLh']
+            self.model_names += ['GLEyel', 'GLEyer', 'GLNose', 'GLMouth', 'GLHair', 'GLBG', 'GCombine']
+            self.auxiliary_model_names += ['CLm', 'CLh']
             # auxiliary nets for local output refinement
             if self.opt.nose_ae:
                 self.auxiliary_model_names += ['AE']
             if self.opt.others_ae:
-                self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack']
+                self.auxiliary_model_names += ['AEel', 'AEer', 'AEmowhite', 'AEmoblack']
         print('model_names', self.model_names)
         print('auxiliary_model_names', self.auxiliary_model_names)
         # load/define networks

@@ -159,55 +174,61 @@ class APDrawingPPStyleModel(BaseModel):
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                          self.gpu_ids)
             print('netD', opt.netD, opt.n_layers_D)
             if self.opt.discriminator_local:
                 self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                   self.gpu_ids)
                 self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                   self.gpu_ids)
                 self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                   self.gpu_ids)
                 self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                    opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                    self.gpu_ids)
                 self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                   self.gpu_ids)
                 self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                 opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
+                                                 self.gpu_ids)
+
         if self.opt.use_local:
             netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
             netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
             netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
             self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
+                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
             self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
+                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
             self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
+                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
             self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
+                                                not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
             self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
+                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
+                                               extra_channel=3)
             self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
+                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
             # by default combiner_type is combiner, which uses resnet
             print('combiner_type', self.opt.combiner_type)
-            self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type,
             # auxiliary classifiers for mouth and hair
             ratio = self.opt.fineSize / 256
             self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
             self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
             self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
             self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
 
         if self.isTrain:
             self.fake_AB_pool = ImagePool(opt.pool_size)

@@ -220,34 +241,39 @@ class APDrawingPPStyleModel(BaseModel):
             if not self.opt.use_local:
                 print('G_params 1 components')
                 self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
             else:
-                G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(
                 print('G_params 8 components')
                 self.optimizer_G = torch.optim.Adam(G_params,
             if not self.opt.discriminator_local:
                 print('D_params 1 components')
                 self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
-            else
-                D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) +
                 print('D_params 7 components')
                 self.optimizer_D = torch.optim.Adam(D_params,
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
         # ==================================auxiliary nets (loaded, parameters fixed)=============================
         if self.opt.use_local and self.opt.nose_ae:
             ratio = self.opt.fineSize / 256
             NOSE_H = self.opt.NOSE_H * ratio
             NOSE_W = self.opt.NOSE_W * ratio
             self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
-            self.set_requires_grad(self.netAE, False)
         if self.opt.use_local and self.opt.others_ae:
             ratio = self.opt.fineSize / 256
             EYE_H = self.opt.EYE_H * ratio

@@ -255,53 +281,51 @@ class APDrawingPPStyleModel(BaseModel):
             MOUTH_H = self.opt.MOUTH_H * ratio
             MOUTH_W = self.opt.MOUTH_W * ratio
             self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
             self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
             self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
             self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
             self.set_requires_grad(self.netAEel, False)
             self.set_requires_grad(self.netAEer, False)
             self.set_requires_grad(self.netAEmowhite, False)
             self.set_requires_grad(self.netAEmoblack, False)
 
         if self.isTrain and self.opt.continuity_loss:
             self.nc = 1
             self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
             self.set_requires_grad(self.netRegressor, False)
 
         if self.isTrain and self.opt.chamfer_loss:
             self.nc = 1
             self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
             self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
             self.set_requires_grad(self.netDT1, False)
             self.set_requires_grad(self.netDT2, False)
             self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
             self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
             self.set_requires_grad(self.netLine1, False)
             self.set_requires_grad(self.netLine2, False)
 
         # ==================================for test (nets loaded, parameters fixed)=============================
-        if
             self.nc = 1
             self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
             self.set_requires_grad(self.netRegressor, False)
 
     def set_input(self, input):
         AtoB = self.opt.which_direction == 'AtoB'

@@ -318,7 +342,7 @@ class APDrawingPPStyleModel(BaseModel):
         self.real_B_eyer = input['eyer_B'].to(self.device)
         self.real_B_nose = input['nose_B'].to(self.device)
         self.real_B_mouth = input['mouth_B'].to(self.device)
-        if self.opt.region_enm in [0,1]:
             self.center = input['center']
             if self.opt.soft_border:
                 self.softel = input['soft_eyel_mask'].to(self.device)

@@ -327,17 +351,17 @@ class APDrawingPPStyleModel(BaseModel):
                 self.softmo = input['soft_mouth_mask'].to(self.device)
             if self.opt.compactmask:
                 self.cmask = input['cmask'].to(self.device)
-                self.cmask1 = self.cmask*2-1#[0,1]->[-1,1]
                 self.cmaskel = input['cmaskel'].to(self.device)
-                self.cmask1el = self.cmaskel*2-1
                 self.cmasker = input['cmasker'].to(self.device)
-                self.cmask1er = self.cmasker*2-1
                 self.cmaskmo = input['cmaskmo'].to(self.device)
-                self.cmask1mo = self.cmaskmo*2-1
         self.real_A_hair = input['hair_A'].to(self.device)
         self.real_B_hair = input['hair_B'].to(self.device)
-        self.mask = input['mask'].to(self.device)
-        self.mask2 = input['mask2'].to(self.device)
         self.real_A_bg = input['bg_A'].to(self.device)
         self.real_B_bg = input['bg_B'].to(self.device)
         if (self.isTrain and self.opt.chamfer_loss):

@@ -345,13 +369,13 @@ class APDrawingPPStyleModel(BaseModel):
             self.dt2gt = input['dt2gt'].to(self.device)
         if self.isTrain and self.opt.emphasis_conti_face:
             self.face_mask = input['face_mask'].cuda(self.gpu_ids_p[0])
 
-    def getonehot(self,outputs,classes):
-        [maxv,index] = torch.max(outputs,1)
-        y = torch.unsqueeze(index,1)
-        onehot = torch.FloatTensor(self.batch_size,classes).to(self.device)
         onehot.zero_()
-        onehot.scatter_(1,y,1)
         return onehot
 
     def forward(self):

@@ -361,40 +385,41 @@ class APDrawingPPStyleModel(BaseModel):
         self.fake_B0 = self.netG(self.real_A)
         # EYES, MOUTH
         outputs1 = self.netCLm(self.real_A_mouth)
-        onehot1 = self.getonehot(outputs1,2)
 
         if not self.opt.others_ae:
             fake_B_eyel = self.netGLEyel(self.real_A_eyel)
             fake_B_eyer = self.netGLEyer(self.real_A_eyer)
             fake_B_mouth = self.netGLMouth(self.real_A_mouth)
-        else:
             self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
             self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
             self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
-            self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1)
-            self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1)
             # USE 2 AEs
-            self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,
             for i in range(self.batch_size):
                 if onehot1[i][0] == 1:
-                    self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
-                    #print('AEmowhite')
                 elif onehot1[i][1] == 1:
-                    self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
-                    #print('AEmoblack')
-            fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel)
-            fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker)
-            fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo)
         # NOSE
         if not self.opt.nose_ae:
             fake_B_nose = self.netGLNose(self.real_A_nose)
-        else:
             self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
-            self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1)
-            fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask)
 
         # for visuals and later local loss
-        if self.opt.region_enm in [0,1]:
             self.fake_B_nose = fake_B_nose
             self.fake_B_eyel = fake_B_eyel
             self.fake_B_eyer = fake_B_eyer

@@ -405,41 +430,48 @@ class APDrawingPPStyleModel(BaseModel):
             self.fake_B_eyel = self.masked(fake_B_eyel, self.softel)
             self.fake_B_eyer = self.masked(fake_B_eyer, self.softer)
             self.fake_B_mouth = self.masked(fake_B_mouth, self.softmo)
-        elif self.opt.region_enm in [2]:
-            self.fake_B_nose = self.masked(fake_B_nose,self.cmask)
-            self.fake_B_eyel = self.masked(fake_B_eyel,self.cmaskel)
-            self.fake_B_eyer = self.masked(fake_B_eyer,self.cmasker)
-            self.fake_B_mouth = self.masked(fake_B_mouth,self.cmaskmo)
 
         # HAIR, BG AND PARTCOMBINE
         outputs2 = self.netCLh(self.real_A_hair)
-        onehot2 = self.getonehot(outputs2,3)
 
         if not self.isTrain:
             opt = self.opt
             if opt.imagefolder == 'images':
-                file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
             else:
-                file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
             with open(file_name, 'a+') as s_file:
                 s_file.write(message)
                 s_file.write('\n')
 
-        fake_B_hair = self.netGLHair(self.real_A_hair,onehot2)
         fake_B_bg = self.netGLBG(self.real_A_bg)
-        self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2)
-        self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2))
         if not self.opt.compactmask:
-            self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,
         else:
-            self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,
 
         # for AE visuals
-        if self.opt.region_enm in [0,1]:
             if self.opt.nose_ae:
                 self.fake_B_nose_v = padpart(self.fake_B_nose, 'nose', self.center, self.opt, self.device)
                 self.fake_B_nose_v1 = padpart(self.fake_B_nose1, 'nose', self.center, self.opt, self.device)

@@ -458,21 +490,20 @@ class APDrawingPPStyleModel(BaseModel):
                 self.fake_B_mouth_v1 = padpart(self.fake_B_mouth1, 'mouth', self.center, self.opt, self.device)
                 self.fake_B_mouth_v2 = padpart(self.fake_B_mouth2, 'mouth', self.center, self.opt, self.device)
                 self.cmask1mo = padpart(self.cmask1mo, 'mouth', self.center, self.opt, self.device)
 
         if not self.isTrain and self.opt.test_continuity_loss:
             self.ContinuityForTest(real=1)
 
     def backward_D(self):
         # Fake
         # stop backprop to the generator by detaching fake_B
         fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
-        #print('fake_AB', fake_AB.shape) # (1,4,512,512)
-        pred_fake = self.netD(fake_AB.detach())# by detach, not affect G's gradient
         self.loss_D_fake = self.criterionGAN(pred_fake, False)
         if self.opt.discriminator_local:
             fake_AB_parts = self.getLocalParts(fake_AB)
-            local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
             self.loss_D_fake_local = 0
             for i in range(len(fake_AB_parts)):
                 net = getattr(self, 'net' + local_names[i])

@@ -487,7 +518,7 @@ class APDrawingPPStyleModel(BaseModel):
         self.loss_D_real = self.criterionGAN(pred_real, True)
         if self.opt.discriminator_local:
             real_AB_parts = self.getLocalParts(real_AB)
-            local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
             self.loss_D_real_local = 0
             for i in range(len(real_AB_parts)):
                 net = getattr(self, 'net' + local_names[i])

@@ -504,12 +535,12 @@ class APDrawingPPStyleModel(BaseModel):
     def backward_G(self):
         # First, G(A) should fake the discriminator
         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
-        pred_fake = self.netD(fake_AB)
         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
         if self.opt.discriminator_local:
             fake_AB_parts = self.getLocalParts(fake_AB)
-            local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG']
-            self.loss_G_GAN_local = 0
             for i in range(len(fake_AB_parts)):
                 net = getattr(self, 'net' + local_names[i])
                 pred_fake_tmp = net(fake_AB_parts[i])

@@ -524,31 +555,34 @@ class APDrawingPPStyleModel(BaseModel):
         # Second, G(A) = B
         if not self.opt.no_l1_loss:
             self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
 
         if self.opt.use_local and not self.opt.no_G_local_loss:
-            local_names = ['eyel','eyer','nose','mouth']
             self.loss_G_local = 0
             for i in range(len(local_names)):
                 fakeblocal = getattr(self, 'fake_B_' + local_names[i])
                 realblocal = getattr(self, 'real_B_' + local_names[i])
                 addw = self.getaddw(local_names[i])
-                self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal,
 
         # Third, chamfer matching (assume chamfer_2way and chamfer_only_line is true)
         if self.opt.chamfer_loss:
             if self.fake_B.shape[1] == 3:
-                tmp = self.fake_B[:,0
                 fake_B_gray = tmp.unsqueeze(1)
             else:
                 fake_B_gray = self.fake_B
             if self.real_B.shape[1] == 3:
-                tmp = self.real_B[:,0
                 real_B_gray = tmp.unsqueeze(1)
             else:
                 real_B_gray = self.real_B
 
             gpu_p = self.opt.gpu_ids_p[0]
             gpu = self.opt.gpu_ids[0]
             if gpu_p != gpu:

@@ -558,22 +592,23 @@ class APDrawingPPStyleModel(BaseModel):
             # d_CM(a_i,G(p_i))
             self.dt1 = self.netDT1(fake_B_gray)
             self.dt2 = self.netDT2(fake_B_gray)
-            dt1 = self.dt1/2.0+0.5#[-1,1]->[0,1]
-            dt2 = self.dt2/2.0+0.5
             if self.opt.dt_nonlinear != '':
                 dt_xmax = torch.Tensor([self.opt.dt_xmax]).cuda(gpu_p)
                 dt1 = nonlinearDt(dt1, self.opt.dt_nonlinear, dt_xmax)
                 dt2 = nonlinearDt(dt2, self.opt.dt_nonlinear, dt_xmax)
-                #print('dt1dt2',torch.min(dt1).item(),torch.max(dt1).item(),torch.min(dt2).item(),torch.max(dt2).item())
 
             bs = real_B_gray.shape[0]
             real_B_gray_line1 = self.netLine1(real_B_gray)
             real_B_gray_line2 = self.netLine2(real_B_gray)
-            self.loss_G_chamfer = (dt1[(real_B_gray<0)&(real_B_gray_line1<0)].sum() + dt2[
             if gpu_p != gpu:
-                self.loss_G_chamfer = self.loss_G_chamfer.cuda(gpu)
 
             if gpu_p != gpu:
                 dt1gt = self.dt1gt.cuda(gpu_p)
                 dt2gt = self.dt2gt.cuda(gpu_p)

@@ -583,13 +618,14 @@ class APDrawingPPStyleModel(BaseModel):
             if self.opt.dt_nonlinear != '':
                 dt1gt = nonlinearDt(dt1gt, self.opt.dt_nonlinear, dt_xmax)
                 dt2gt = nonlinearDt(dt2gt, self.opt.dt_nonlinear, dt_xmax)
-                #print('dt1gtdt2gt',torch.min(dt1gt).item(),torch.max(dt1gt).item(),torch.min(dt2gt).item(),torch.max(dt2gt).item())
-            self.dt1gt = (self.dt1gt-0.5)*2
-            self.dt2gt = (self.dt2gt-0.5)*2
 
             fake_B_gray_line1 = self.netLine1(fake_B_gray)
             fake_B_gray_line2 = self.netLine2(fake_B_gray)
-            self.loss_G_chamfer2 = (dt1gt[(fake_B_gray<0)&(fake_B_gray_line1<0)].sum() + dt2gt[
             if gpu_p != gpu:
                 self.loss_G_chamfer2 = self.loss_G_chamfer2.cuda(gpu)

@@ -599,11 +635,10 @@ class APDrawingPPStyleModel(BaseModel):
             self.get_patches()
             self.outputs = self.netRegressor(self.fake_B_patches)
             if not self.opt.emphasis_conti_face:
-                self.loss_G_continuity = (1.0-torch.mean(self.outputs)).cuda(gpu) * self.opt.lambda_continuity
             else:
-                self.loss_G_continuity = torch.mean((1.0-self.outputs)*self.conti_weights).cuda(
 
         self.loss_G = self.loss_G_GAN
         if 'G_L1' in self.loss_names:

@@ -627,7 +662,7 @@ class APDrawingPPStyleModel(BaseModel):
         self.forward()
         # update D
         self.set_requires_grad(self.netD, True)
-
         if self.opt.discriminator_local:
             self.set_requires_grad(self.netDLEyel, True)
             self.set_requires_grad(self.netDLEyer, True)

@@ -661,32 +696,32 @@ class APDrawingPPStyleModel(BaseModel):
         patches = []
         if self.isTrain and self.opt.emphasis_conti_face:
             weights = []
-        W2 = int(W/2)
-        t = np.random.randint(res,size=2)
         for i in range(aa):
             for j in range(aa):
-                p = self.fake_B[
-                whitenum = torch.sum(p>=0.0)
-                #if whitenum < 5 or whitenum > W*W-5:
-                if whitenum < 1 or whitenum > W*W-1:
                     continue
                 patches.append(p)
                 if self.isTrain and self.opt.emphasis_conti_face:
-                    weights.append(self.face_mask[
         self.fake_B_patches = torch.cat(patches, dim=0)
         if self.isTrain and self.opt.emphasis_conti_face:
-            self.conti_weights = torch.cat(weights, dim=0)+1
 
     def get_patches_real(self):
         # [1,1,512,512]->[bs,1,11,11]
         patches = []
-        t = np.random.randint(res,size=2)
         for i in range(aa):
             for j in range(aa):
-                p = self.real_B[
-                whitenum = torch.sum(p>=0.0)
-                #if whitenum < 5 or whitenum > W*W-5:
-                if whitenum < 1 or whitenum > W*W-1:
                     continue
                 patches.append(p)
         self.real_B_patches = torch.cat(patches, dim=0)
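
The partially truncated getonehot above arg-maxes classifier outputs and scatters them into a one-hot matrix; it drives the 2-class mouth and 3-class hair style branches in forward. The same scatter pattern, standalone and with batch-size and device handling stripped:

    import torch

    def getonehot(outputs, classes):
        # outputs: (batch, classes) raw scores
        _, index = torch.max(outputs, 1)
        y = torch.unsqueeze(index, 1)
        onehot = torch.zeros(outputs.shape[0], classes)
        onehot.scatter_(1, y, 1)
        return onehot

    scores = torch.tensor([[0.2, 1.3, -0.5], [2.0, 0.1, 0.3]])
    print(getonehot(scores, 3))    # tensor([[0., 1., 0.], [1., 0., 0.]])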
|
|
|
|
| 6 |
import math
|
| 7 |
|
| 8 |
W = 11
|
| 9 |
+
aa = int(math.floor(512. / W))
|
| 10 |
+
res = 512 - W * aa
|
| 11 |
|
| 12 |
|
| 13 |
+
def padpart(A, part, centers, opt, device):
|
| 14 |
IMAGE_SIZE = opt.fineSize
|
| 15 |
+
bs, nc, _, _ = A.shape
|
| 16 |
ratio = IMAGE_SIZE / 256
|
| 17 |
NOSE_W = opt.NOSE_W * ratio
|
| 18 |
NOSE_H = opt.NOSE_H * ratio
|
|
|
|
| 20 |
EYE_H = opt.EYE_H * ratio
|
| 21 |
MOUTH_W = opt.MOUTH_W * ratio
|
| 22 |
MOUTH_H = opt.MOUTH_H * ratio
|
| 23 |
+
A_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(device)
|
| 24 |
+
padvalue = -1 # black
|
| 25 |
for i in range(bs):
|
| 26 |
center = centers[i]
|
| 27 |
if part == 'nose':
|
| 28 |
+
A_p[i] = torch.nn.ConstantPad2d((
|
| 29 |
+
int(center[2, 0] - NOSE_W / 2), IMAGE_SIZE - int(center[2, 0] + NOSE_W / 2),
|
| 30 |
+
int(center[2, 1] - NOSE_H / 2),
|
| 31 |
+
IMAGE_SIZE - int(center[2, 1] + NOSE_H / 2)), padvalue)(A[i])
|
| 32 |
elif part == 'eyel':
|
| 33 |
+
A_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2), IMAGE_SIZE - int(center[0, 0] + EYE_W / 2),
|
| 34 |
+
int(center[0, 1] - EYE_H / 2), IMAGE_SIZE - int(center[0, 1] + EYE_H / 2)),
|
| 35 |
+
padvalue)(A[i])
|
| 36 |
elif part == 'eyer':
|
| 37 |
+
A_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2), IMAGE_SIZE - int(center[1, 0] + EYE_W / 2),
|
| 38 |
+
int(center[1, 1] - EYE_H / 2), IMAGE_SIZE - int(center[1, 1] + EYE_H / 2)),
|
| 39 |
+
padvalue)(A[i])
|
| 40 |
elif part == 'mouth':
|
| 41 |
+
A_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2),
|
| 42 |
+
IMAGE_SIZE - int(center[3, 0] + MOUTH_W / 2),
|
| 43 |
+
int(center[3, 1] - MOUTH_H / 2),
|
| 44 |
+
IMAGE_SIZE - int(center[3, 1] + MOUTH_H / 2)), padvalue)(A[i])
|
| 45 |
return A_p
|
| 46 |
|
| 47 |
+
|
| 48 |
import numpy as np
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def nonlinearDt(dt, type='atan',
|
| 52 |
+
xmax=torch.Tensor([10.0])): # dt in [0,1], first multiply xmax(>1), then remap to [0,1]
|
| 53 |
if type == 'atan':
|
| 54 |
+
nldt = torch.atan(dt * xmax) / torch.atan(xmax)
|
| 55 |
elif type == 'sigmoid':
|
| 56 |
+
nldt = (torch.sigmoid(dt * xmax) - 0.5) / (torch.sigmoid(xmax) - 0.5)
|
| 57 |
elif type == 'tanh':
|
| 58 |
+
nldt = torch.tanh(dt * xmax) / torch.tanh(xmax)
|
| 59 |
elif type == 'pow':
|
| 60 |
+
nldt = torch.pow(dt * xmax, 2) / torch.pow(xmax, 2)
|
| 61 |
elif type == 'exp':
|
| 62 |
+
if xmax.item() > 1:
|
| 63 |
xmax = xmax / 3
|
| 64 |
+
nldt = (torch.exp(dt * xmax) - 1) / (torch.exp(xmax) - 1)
|
| 65 |
+
# print("remap dt:", type, xmax.item())
|
| 66 |
return nldt


class APDrawingPPStyleModel(BaseModel):
    def name(self):
        return 'APDrawingPPStyleModel'

    # ...
        # changing the default values to match the pix2pix paper
        # (https://phillipi.github.io/pix2pix/)
        parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')  # no_lsgan=True, use_lsgan=False
        parser.set_defaults(dataset_mode='aligned')
        parser.set_defaults(auxiliary_root='auxiliaryeye2o')
        parser.set_defaults(use_local=True, hair_local=True, bg_local=True)

    # ...
        self.visual_names += ['fake_B0', 'fake_B1']
        self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair']
        self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg']
        if self.opt.region_enm in [0, 1]:
            if self.opt.nose_ae:
                self.visual_names += ['fake_B_nose_v', 'fake_B_nose_v1', 'fake_B_nose_v2', 'cmask1no']
            if self.opt.others_ae:
                self.visual_names += ['fake_B_eyel_v', 'fake_B_eyel_v1', 'fake_B_eyel_v2', 'cmask1el']
                self.visual_names += ['fake_B_eyer_v', 'fake_B_eyer_v1', 'fake_B_eyer_v2', 'cmask1er']
                self.visual_names += ['fake_B_mouth_v', 'fake_B_mouth_v1', 'fake_B_mouth_v2', 'cmask1mo']
        elif self.opt.region_enm in [2]:
            self.visual_names += ['fake_B_nose', 'fake_B_eyel', 'fake_B_eyer', 'fake_B_mouth']
        if self.isTrain and self.opt.chamfer_loss:
            self.visual_names += ['dt1', 'dt2']
            self.visual_names += ['dt1gt', 'dt2gt']

    # ...
        if self.isTrain:
            self.model_names = ['G', 'D']
            if self.opt.discriminator_local:
                self.model_names += ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            # auxiliary nets for loss calculation
            if self.opt.chamfer_loss:
                self.auxiliary_model_names += ['DT1', 'DT2']
        # ...
        if self.opt.test_continuity_loss:
            self.auxiliary_model_names += ['Regressor']
        if self.opt.use_local:
            self.model_names += ['GLEyel', 'GLEyer', 'GLNose', 'GLMouth', 'GLHair', 'GLBG', 'GCombine']
            self.auxiliary_model_names += ['CLm', 'CLh']
            # auxiliary nets for local output refinement
            if self.opt.nose_ae:
                self.auxiliary_model_names += ['AE']
            if self.opt.others_ae:
                self.auxiliary_model_names += ['AEel', 'AEer', 'AEmowhite', 'AEmoblack']
        print('model_names', self.model_names)
        print('auxiliary_model_names', self.auxiliary_model_names)
        # load/define networks

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                          self.gpu_ids)
            print('netD', opt.netD, opt.n_layers_D)
            if self.opt.discriminator_local:
                self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                   self.gpu_ids)
                self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                   self.gpu_ids)
                self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                   self.gpu_ids)
                self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                    opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                    self.gpu_ids)
                self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                   opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                   self.gpu_ids)
                self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
                                                 opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain,
                                                 self.gpu_ids)

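# Editor's sketch (not in the original source): the six local discriminators
# are architecturally identical, so an equivalent construction is a loop over
# the part names (shown only to illustrate the structure, not a refactor
# applied to the file):
#
#     for part in ['Eyel', 'Eyer', 'Nose', 'Mouth', 'Hair', 'BG']:
#         setattr(self, 'netDL' + part,
#                 networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
#                                   opt.n_layers_D, opt.norm, use_sigmoid,
#                                   opt.init_type, opt.init_gain, self.gpu_ids))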
        if self.opt.use_local:
            netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks'
            netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks'
            netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks'
            self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
            self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
            self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
            self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm,
                                                not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3)
            self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm,
                                               not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4,
                                               extra_channel=3)
            self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm,
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4)
            # by default combiner_type is combiner, which uses resnet
            print('combiner_type', self.opt.combiner_type)
            self.netGCombine = networks.define_G(2 * opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type,
                                                 opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain,
                                                 self.gpu_ids, 2)
            # auxiliary classifiers for mouth and hair
            ratio = self.opt.fineSize / 256
            self.MOUTH_H = int(self.opt.MOUTH_H * ratio)
            self.MOUTH_W = int(self.opt.MOUTH_W * ratio)
            self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                            nnG=3, ae_h=self.MOUTH_H, ae_w=self.MOUTH_W)
            self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                            nnG=opt.nnG_hairc, ae_h=opt.fineSize, ae_w=opt.fineSize)

    # ...
        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)

    # ...
            if not self.opt.use_local:
                print('G_params 1 components')
                self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(
                    self.netGLEyer.parameters()) + list(self.netGLNose.parameters()) + list(
                    self.netGLMouth.parameters()) + list(self.netGCombine.parameters()) + list(
                    self.netGLHair.parameters()) + list(self.netGLBG.parameters())
                print('G_params 8 components')
                self.optimizer_G = torch.optim.Adam(G_params,
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))

            if not self.opt.discriminator_local:
                print('D_params 1 components')
                self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:  # self.opt.discriminator_local == True
                D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) + list(
                    self.netDLEyer.parameters()) + list(self.netDLNose.parameters()) + list(
                    self.netDLMouth.parameters()) + list(self.netDLHair.parameters()) + list(self.netDLBG.parameters())
                print('D_params 7 components')
                self.optimizer_D = torch.optim.Adam(D_params,
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

        # ==================================auxiliary nets (loaded, parameters fixed)=============================
        if self.opt.use_local and self.opt.nose_ae:
            ratio = self.opt.fineSize / 256
            NOSE_H = self.opt.NOSE_H * ratio
            NOSE_W = self.opt.NOSE_W * ratio
            self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                           not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                           latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W)
            self.set_requires_grad(self.netAE, False)
        if self.opt.use_local and self.opt.others_ae:
            ratio = self.opt.fineSize / 256
            EYE_H = self.opt.EYE_H * ratio
            # ...
            MOUTH_H = self.opt.MOUTH_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
            self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                             not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                             latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W)
            self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
            self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch',
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W)
            self.set_requires_grad(self.netAEel, False)
            self.set_requires_grad(self.netAEer, False)
            self.set_requires_grad(self.netAEmowhite, False)
            self.set_requires_grad(self.netAEmoblack, False)

        if self.isTrain and self.opt.continuity_loss:
            self.nc = 1
            self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p,
                                                  nnG=opt.regarch)
            self.set_requires_grad(self.netRegressor, False)

        if self.isTrain and self.opt.chamfer_loss:
            self.nc = 1
            self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm,
                                            not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.set_requires_grad(self.netDT1, False)
            self.set_requires_grad(self.netDT2, False)
            self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm,
                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p)
            self.set_requires_grad(self.netLine1, False)
            self.set_requires_grad(self.netLine2, False)

        # ==================================for test (nets loaded, parameters fixed)=============================
        if not self.isTrain and self.opt.test_continuity_loss:
            self.nc = 1
            self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm,
                                                  not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                                  nnG=opt.regarch)
            self.set_requires_grad(self.netRegressor, False)

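# Editor's sketch (not in the original source): the auxiliary nets above are
# loss evaluators or fixed refiners, so they are frozen and only the
# netG*/netD* parameters collected into the optimizers earlier are trained.
# A minimal equivalent of the freeze, assuming ordinary nn.Module instances:
#
#     for net in [self.netDT1, self.netDT2, self.netLine1, self.netLine2]:
#         for p in net.parameters():
#             p.requires_grad = False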
    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        # ...
        self.real_B_eyer = input['eyer_B'].to(self.device)
        self.real_B_nose = input['nose_B'].to(self.device)
        self.real_B_mouth = input['mouth_B'].to(self.device)
        if self.opt.region_enm in [0, 1]:
            self.center = input['center']
            if self.opt.soft_border:
                self.softel = input['soft_eyel_mask'].to(self.device)
                # ...
                self.softmo = input['soft_mouth_mask'].to(self.device)
        if self.opt.compactmask:
            self.cmask = input['cmask'].to(self.device)
            self.cmask1 = self.cmask * 2 - 1  # [0,1] -> [-1,1]
            self.cmaskel = input['cmaskel'].to(self.device)
            self.cmask1el = self.cmaskel * 2 - 1
            self.cmasker = input['cmasker'].to(self.device)
            self.cmask1er = self.cmasker * 2 - 1
            self.cmaskmo = input['cmaskmo'].to(self.device)
            self.cmask1mo = self.cmaskmo * 2 - 1
        self.real_A_hair = input['hair_A'].to(self.device)
        self.real_B_hair = input['hair_B'].to(self.device)
        self.mask = input['mask'].to(self.device)  # mask for the non-eyes/nose/mouth region
        self.mask2 = input['mask2'].to(self.device)  # mask for the non-background region
        self.real_A_bg = input['bg_A'].to(self.device)
        self.real_B_bg = input['bg_B'].to(self.device)
        if self.isTrain and self.opt.chamfer_loss:
            # ...
            self.dt2gt = input['dt2gt'].to(self.device)
        if self.isTrain and self.opt.emphasis_conti_face:
            self.face_mask = input['face_mask'].cuda(self.gpu_ids_p[0])

    def getonehot(self, outputs, classes):
        [maxv, index] = torch.max(outputs, 1)
        y = torch.unsqueeze(index, 1)
        onehot = torch.FloatTensor(self.batch_size, classes).to(self.device)
        onehot.zero_()
        onehot.scatter_(1, y, 1)
        return onehot
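# Editor's sketch (not in the original source): getonehot turns raw classifier
# scores into a hard one-hot code via argmax. E.g. with batch_size == 1 and
# classes == 2, scores [[0.2, 1.3]] scatter to [[0., 1.]]:
#
#     outputs = torch.tensor([[0.2, 1.3]])
#     _, index = torch.max(outputs, 1)                                 # tensor([1])
#     onehot = torch.zeros(1, 2).scatter_(1, index.unsqueeze(1), 1)    # [[0., 1.]]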

    def forward(self):
        # ...
        self.fake_B0 = self.netG(self.real_A)
        # EYES, MOUTH
        outputs1 = self.netCLm(self.real_A_mouth)
        onehot1 = self.getonehot(outputs1, 2)

        if not self.opt.others_ae:
            fake_B_eyel = self.netGLEyel(self.real_A_eyel)
            fake_B_eyer = self.netGLEyer(self.real_A_eyer)
            fake_B_mouth = self.netGLMouth(self.real_A_mouth)
        else:  # use AEs that only contain the compact region; needs cmask!
            self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel)
            self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer)
            self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth)
            self.fake_B_eyel2, _ = self.netAEel(self.fake_B_eyel1)
            self.fake_B_eyer2, _ = self.netAEer(self.fake_B_eyer1)
            # use two AEs for the mouth, chosen per sample by the style one-hot
            self.fake_B_mouth2 = torch.FloatTensor(self.batch_size, self.opt.output_nc, self.MOUTH_H,
                                                   self.MOUTH_W).to(self.device)
            for i in range(self.batch_size):
                if onehot1[i][0] == 1:
                    self.fake_B_mouth2[i], _ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0))
                    # print('AEmowhite')
                elif onehot1[i][1] == 1:
                    self.fake_B_mouth2[i], _ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0))
                    # print('AEmoblack')
            fake_B_eyel = self.add_with_mask(self.fake_B_eyel2, self.fake_B_eyel1, self.cmaskel)
            fake_B_eyer = self.add_with_mask(self.fake_B_eyer2, self.fake_B_eyer1, self.cmasker)
            fake_B_mouth = self.add_with_mask(self.fake_B_mouth2, self.fake_B_mouth1, self.cmaskmo)
        # NOSE
        if not self.opt.nose_ae:
            fake_B_nose = self.netGLNose(self.real_A_nose)
        else:  # use an AE that only contains the compact region; needs cmask!
            self.fake_B_nose1 = self.netGLNose(self.real_A_nose)
            self.fake_B_nose2, _ = self.netAE(self.fake_B_nose1)
            fake_B_nose = self.add_with_mask(self.fake_B_nose2, self.fake_B_nose1, self.cmask)

        # for visuals and the local losses below
        if self.opt.region_enm in [0, 1]:
            self.fake_B_nose = fake_B_nose
            self.fake_B_eyel = fake_B_eyel
            self.fake_B_eyer = fake_B_eyer
            # ...
            self.fake_B_eyel = self.masked(fake_B_eyel, self.softel)
            self.fake_B_eyer = self.masked(fake_B_eyer, self.softer)
            self.fake_B_mouth = self.masked(fake_B_mouth, self.softmo)
        elif self.opt.region_enm in [2]:  # need to multiply cmask
            self.fake_B_nose = self.masked(fake_B_nose, self.cmask)
            self.fake_B_eyel = self.masked(fake_B_eyel, self.cmaskel)
            self.fake_B_eyer = self.masked(fake_B_eyer, self.cmasker)
            self.fake_B_mouth = self.masked(fake_B_mouth, self.cmaskmo)

        # HAIR, BG AND PARTCOMBINE
        outputs2 = self.netCLh(self.real_A_hair)
        onehot2 = self.getonehot(outputs2, 3)

        if not self.isTrain:
            opt = self.opt
            if opt.imagefolder == 'images':
                file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
                                         'styleonehot.txt')
            else:
                file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
                                         opt.imagefolder, 'styleonehot.txt')
            message = '%s [%d %d] [%d %d %d]' % (self.image_paths[0], onehot1[0][0], onehot1[0][1],
                                                 onehot2[0][0], onehot2[0][1], onehot2[0][2])
            with open(file_name, 'a+') as s_file:
                s_file.write(message)
                s_file.write('\n')

        fake_B_hair = self.netGLHair(self.real_A_hair, onehot2)
        fake_B_bg = self.netGLBG(self.real_A_bg)
        self.fake_B_hair = self.masked(fake_B_hair, self.mask * self.mask2)
        self.fake_B_bg = self.masked(fake_B_bg, self.inverse_mask(self.mask2))
        if not self.opt.compactmask:
            self.fake_B1 = self.partCombiner2_bg(fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair,
                                                 fake_B_bg, self.mask * self.mask2, self.inverse_mask(self.mask2),
                                                 self.opt.comb_op)
        else:
            self.fake_B1 = self.partCombiner2_bg(fake_B_eyel, fake_B_eyer, fake_B_nose, fake_B_mouth, fake_B_hair,
                                                 fake_B_bg, self.mask * self.mask2, self.inverse_mask(self.mask2),
                                                 self.opt.comb_op, self.opt.region_enm, self.cmaskel, self.cmasker,
                                                 self.cmask, self.cmaskmo)

        self.fake_B = self.netGCombine(torch.cat([self.fake_B0, self.fake_B1], 1))

        # for AE visuals
        if self.opt.region_enm in [0, 1]:
            if self.opt.nose_ae:
                self.fake_B_nose_v = padpart(self.fake_B_nose, 'nose', self.center, self.opt, self.device)
                self.fake_B_nose_v1 = padpart(self.fake_B_nose1, 'nose', self.center, self.opt, self.device)
                # ...
                self.fake_B_mouth_v1 = padpart(self.fake_B_mouth1, 'mouth', self.center, self.opt, self.device)
                self.fake_B_mouth_v2 = padpart(self.fake_B_mouth2, 'mouth', self.center, self.opt, self.device)
                self.cmask1mo = padpart(self.cmask1mo, 'mouth', self.center, self.opt, self.device)

        if not self.isTrain and self.opt.test_continuity_loss:
            self.ContinuityForTest(real=1)

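# Editor's sketch (not in the original source): forward() is a two-branch
# fusion. fake_B0 is the global drawing, fake_B1 is the mosaic re-assembled
# from the six local outputs, and netGCombine fuses them along channels:
#
#     fake_B0 = self.netG(self.real_A)       # (bs, output_nc, H, W)
#     fake_B1 = self.partCombiner2_bg(...)   # (bs, output_nc, H, W)
#     fake_B = self.netGCombine(torch.cat([fake_B0, fake_B1], 1))  # 2*output_nc in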
    def backward_D(self):
        # Fake: stop backprop to the generator by detaching fake_B
        fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        # print('fake_AB', fake_AB.shape)  # (1, 4, 512, 512)
        pred_fake = self.netD(fake_AB.detach())  # detached, so D's loss does not reach G's gradients
        self.loss_D_fake = self.criterionGAN(pred_fake, False)
        if self.opt.discriminator_local:
            fake_AB_parts = self.getLocalParts(fake_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_D_fake_local = 0
            for i in range(len(fake_AB_parts)):
                net = getattr(self, 'net' + local_names[i])
        # ...
        self.loss_D_real = self.criterionGAN(pred_real, True)
        if self.opt.discriminator_local:
            real_AB_parts = self.getLocalParts(real_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_D_real_local = 0
            for i in range(len(real_AB_parts)):
                net = getattr(self, 'net' + local_names[i])

    def backward_G(self):
        # First, G(A) should fool the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD(fake_AB)  # (1,4,512,512) -> (1,1,30,30)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)
        if self.opt.discriminator_local:
            fake_AB_parts = self.getLocalParts(fake_AB)
            local_names = ['DLEyel', 'DLEyer', 'DLNose', 'DLMouth', 'DLHair', 'DLBG']
            self.loss_G_GAN_local = 0  # G_GAN_local is later added into G_GAN
            for i in range(len(fake_AB_parts)):
                net = getattr(self, 'net' + local_names[i])
                pred_fake_tmp = net(fake_AB_parts[i])
        # ...
        # Second, G(A) = B
        if not self.opt.no_l1_loss:
            self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1

        if self.opt.use_local and not self.opt.no_G_local_loss:
            local_names = ['eyel', 'eyer', 'nose', 'mouth']
            self.loss_G_local = 0
            for i in range(len(local_names)):
                fakeblocal = getattr(self, 'fake_B_' + local_names[i])
                realblocal = getattr(self, 'real_B_' + local_names[i])
                addw = self.getaddw(local_names[i])
                self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal,
                                                                         realblocal) * self.opt.lambda_local * addw
            self.loss_G_hair_local = self.criterionL1(self.fake_B_hair,
                                                      self.real_B_hair) * self.opt.lambda_local * self.opt.addw_hair
            self.loss_G_bg_local = self.criterionL1(self.fake_B_bg,
                                                    self.real_B_bg) * self.opt.lambda_local * self.opt.addw_bg

        # Third, chamfer matching (assumes chamfer_2way and chamfer_only_line are true)
        if self.opt.chamfer_loss:
            if self.fake_B.shape[1] == 3:
                tmp = self.fake_B[:, 0, ...] * 0.299 + self.fake_B[:, 1, ...] * 0.587 + self.fake_B[:, 2, ...] * 0.114
                fake_B_gray = tmp.unsqueeze(1)
            else:
                fake_B_gray = self.fake_B
            if self.real_B.shape[1] == 3:
                tmp = self.real_B[:, 0, ...] * 0.299 + self.real_B[:, 1, ...] * 0.587 + self.real_B[:, 2, ...] * 0.114
                real_B_gray = tmp.unsqueeze(1)
            else:
                real_B_gray = self.real_B

            gpu_p = self.opt.gpu_ids_p[0]
            gpu = self.opt.gpu_ids[0]
            if gpu_p != gpu:
            # ...
            # d_CM(a_i, G(p_i))
            self.dt1 = self.netDT1(fake_B_gray)
            self.dt2 = self.netDT2(fake_B_gray)
            dt1 = self.dt1 / 2.0 + 0.5  # [-1,1] -> [0,1]
            dt2 = self.dt2 / 2.0 + 0.5
            if self.opt.dt_nonlinear != '':
                dt_xmax = torch.Tensor([self.opt.dt_xmax]).cuda(gpu_p)
                dt1 = nonlinearDt(dt1, self.opt.dt_nonlinear, dt_xmax)
                dt2 = nonlinearDt(dt2, self.opt.dt_nonlinear, dt_xmax)
                # print('dt1dt2', torch.min(dt1).item(), torch.max(dt1).item(), torch.min(dt2).item(), torch.max(dt2).item())

            bs = real_B_gray.shape[0]
            real_B_gray_line1 = self.netLine1(real_B_gray)
            real_B_gray_line2 = self.netLine2(real_B_gray)
            self.loss_G_chamfer = (dt1[(real_B_gray < 0) & (real_B_gray_line1 < 0)].sum() + dt2[
                (real_B_gray >= 0) & (real_B_gray_line2 >= 0)].sum()) / bs * self.opt.lambda_chamfer
            if gpu_p != gpu:
                self.loss_G_chamfer = self.loss_G_chamfer.cuda(gpu)

            # d_CM(G(p_i), a_i)
            if gpu_p != gpu:
                dt1gt = self.dt1gt.cuda(gpu_p)
                dt2gt = self.dt2gt.cuda(gpu_p)
            # ...
            if self.opt.dt_nonlinear != '':
                dt1gt = nonlinearDt(dt1gt, self.opt.dt_nonlinear, dt_xmax)
                dt2gt = nonlinearDt(dt2gt, self.opt.dt_nonlinear, dt_xmax)
                # print('dt1gtdt2gt', torch.min(dt1gt).item(), torch.max(dt1gt).item(), torch.min(dt2gt).item(), torch.max(dt2gt).item())
            self.dt1gt = (self.dt1gt - 0.5) * 2
            self.dt2gt = (self.dt2gt - 0.5) * 2

            fake_B_gray_line1 = self.netLine1(fake_B_gray)
            fake_B_gray_line2 = self.netLine2(fake_B_gray)
            self.loss_G_chamfer2 = (dt1gt[(fake_B_gray < 0) & (fake_B_gray_line1 < 0)].sum() + dt2gt[
                (fake_B_gray >= 0) & (fake_B_gray_line2 >= 0)].sum()) / bs * self.opt.lambda_chamfer2
            if gpu_p != gpu:
                self.loss_G_chamfer2 = self.loss_G_chamfer2.cuda(gpu)

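# Editor's sketch (not in the original source): each chamfer term integrates a
# predicted distance transform of one drawing over the line pixels of the
# other. dt1 approximates, per pixel, the distance to the nearest dark stroke
# of fake_B, so summing it where the real drawing (and its line map) are dark
# measures how far the target lines lie from the nearest generated stroke:
#
#     on_real_lines = (real_B_gray < 0) & (real_B_gray_line1 < 0)  # bool mask
#     term = dt1[on_real_lines].sum() / bs   # approaches 0 when strokes coincide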
        # ...
            self.get_patches()
            self.outputs = self.netRegressor(self.fake_B_patches)
            if not self.opt.emphasis_conti_face:
                self.loss_G_continuity = (1.0 - torch.mean(self.outputs)).cuda(gpu) * self.opt.lambda_continuity
            else:
                self.loss_G_continuity = torch.mean((1.0 - self.outputs) * self.conti_weights).cuda(
                    gpu) * self.opt.lambda_continuity

        self.loss_G = self.loss_G_GAN
        if 'G_L1' in self.loss_names:
        # ...
        self.forward()
        # update D
        self.set_requires_grad(self.netD, True)

        if self.opt.discriminator_local:
            self.set_requires_grad(self.netDLEyel, True)
            self.set_requires_grad(self.netDLEyer, True)

    # ...
        patches = []
        if self.isTrain and self.opt.emphasis_conti_face:
            weights = []
        W2 = int(W / 2)
        t = np.random.randint(res, size=2)
        for i in range(aa):
            for j in range(aa):
                p = self.fake_B[:, :, t[0] + i * W:t[0] + (i + 1) * W, t[1] + j * W:t[1] + (j + 1) * W]
                whitenum = torch.sum(p >= 0.0)
                # if whitenum < 5 or whitenum > W * W - 5:
                if whitenum < 1 or whitenum > W * W - 1:
                    continue
                patches.append(p)
                if self.isTrain and self.opt.emphasis_conti_face:
                    weights.append(self.face_mask[:, :, t[0] + i * W + W2, t[1] + j * W + W2])
        self.fake_B_patches = torch.cat(patches, dim=0)
        if self.isTrain and self.opt.emphasis_conti_face:
            self.conti_weights = torch.cat(weights, dim=0) + 1  # 0 -> 1, 1 -> 2

    def get_patches_real(self):
        # [1,1,512,512] -> [bs,1,11,11]
        patches = []
        t = np.random.randint(res, size=2)
        for i in range(aa):
            for j in range(aa):
                p = self.real_B[:, :, t[0] + i * W:t[0] + (i + 1) * W, t[1] + j * W:t[1] + (j + 1) * W]
                whitenum = torch.sum(p >= 0.0)
                # if whitenum < 5 or whitenum > W * W - 5:
                if whitenum < 1 or whitenum > W * W - 1:
                    continue
                patches.append(p)
        self.real_B_patches = torch.cat(patches, dim=0)

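# Editor's note (sketch, not in the original source): W, aa and res are not
# defined in this excerpt; they behave as module-level constants of this file.
# W is the patch width (the [1,1,512,512] -> [bs,1,11,11] comment suggests
# W == 11), aa the number of patches per axis, and res the range of the random
# offset t, so the slices stay in bounds roughly when res + aa * W <= fineSize.
# The whitenum filter drops all-black and all-white windows, which carry no
# information about line continuity.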
APDrawingGAN2/models/base_model.py
CHANGED
@@ -20,8 +20,8 @@ class BaseModel():
        self.gpu_ids = opt.gpu_ids
        self.gpu_ids_p = opt.gpu_ids_p
        self.isTrain = opt.isTrain
-        self.device = torch.device('cpu')
-        self.device_p = torch.device('cpu')
+        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+        self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root)
        if opt.resize_or_crop != 'scale_width':
@@ -105,7 +105,7 @@ class BaseModel():
            net.cuda(self.gpu_ids[0])
        else:
            torch.save(net.cpu().state_dict(), save_path)
-
+
    def save_networks2(self, which_epoch):
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch))
@@ -120,7 +120,7 @@ class BaseModel():
            net.cuda(self.gpu_ids[0])
        else:
            state_dict = net.cpu().state_dict()
-
+
        if name[0] == 'G':
            dict_gen[name] = state_dict
        elif name[0] == 'D':
@@ -142,7 +142,7 @@ class BaseModel():
            if getattr(module, key) is None:
                state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
-
+                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
@@ -171,7 +171,7 @@ class BaseModel():
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
        net.load_state_dict(state_dict)
-
+
    def load_networks2(self, which_epoch):
        gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch))
        gen_state_dict = torch.load(gen_name, map_location=str(self.device))
@@ -184,19 +184,19 @@ class BaseModel():
        if isinstance(net, torch.nn.DataParallel):
            net = net.module
        if name[0] == 'G':
-            print('loading the model %s from %s' % (name,gen_name))
+            print('loading the model %s from %s' % (name, gen_name))
            state_dict = gen_state_dict[name]
        elif name[0] == 'D':
-            print('loading the model %s from %s' % (name,gen_name))
+            print('loading the model %s from %s' % (name, gen_name))
            state_dict = dis_state_dict[name]
-
+
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata
        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
        net.load_state_dict(state_dict)
-
+
    # load auxiliary net models from the disk
    def load_auxiliary_networks(self):
        for name in self.auxiliary_model_names:
@@ -214,7 +214,8 @@ class BaseModel():
        print('loading the model from %s' % load_path)
        # if you are using PyTorch newer than 0.4 (e.g., built from
        # GitHub source), you can remove str() on self.device
-        if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair',
+        if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair',
+                    'Regressorface']:
            state_dict = torch.load(load_path, map_location=str(self.device_p))
        else:
            state_dict = torch.load(load_path, map_location=str(self.device))
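# Editor's sketch (not in the original source): the device change in the first
# hunk is what makes the checkpoint loading above work on both CPU-only and
# GPU hosts (note both device and device_p key off self.gpu_ids);
# map_location=str(self.device) then remaps every stored tensor onto the
# device chosen at startup, e.g. (illustrative path):
#
#     device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
#     state_dict = torch.load('checkpoints/latest_net_G.pth', map_location=str(device))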
@@ -251,18 +252,19 @@ class BaseModel():

    # =============================================================================================================
    def inverse_mask(self, mask):
-        return torch.ones(mask.shape).to(self.device)-mask
-
-    def masked(self, A,mask):
-        return (A/2+0.5)*mask*2-1
-
-    def add_with_mask(self, A,B,mask):
-        return ((A/2+0.5)*mask+(B/2+0.5)*(torch.ones(mask.shape).to(self.device)-mask))*2-1
-
-    def addone_with_mask(self, A,mask):
-        return ((A/2+0.5)*mask+(torch.ones(mask.shape).to(self.device)-mask))*2-1
-
-    def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op
+        return torch.ones(mask.shape).to(self.device) - mask
+
+    def masked(self, A, mask):
+        return (A / 2 + 0.5) * mask * 2 - 1
+
+    def add_with_mask(self, A, B, mask):
+        return ((A / 2 + 0.5) * mask + (B / 2 + 0.5) * (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1
+
+    def addone_with_mask(self, A, mask):
+        return ((A / 2 + 0.5) * mask + (torch.ones(mask.shape).to(self.device) - mask)) * 2 - 1
+
+    def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op=1, region_enm=0, cmaskel=None,
+                     cmasker=None, cmaskno=None, cmaskmo=None):
        '''
        x y
        100.571 123.429
@@ -276,7 +278,7 @@ class BaseModel():
        if comb_op == 0:
            # use max pooling, pad black for eyes etc
            padvalue = -1
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
@@ -284,12 +286,12 @@ class BaseModel():
        else:
            # use min pooling, pad white for eyes etc
            padvalue = 1
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
-        if region_enm in [0,1]:
+        if region_enm in [0, 1]:  # need to pad
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
@@ -298,20 +300,32 @@ class BaseModel():
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
-            bs,nc,_,_ = eyel.shape
-            eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
+            bs, nc, _, _ = eyel.shape
+            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
                if not average_pos:
-                    center = self.center[i]#x,y
-                    else
-                    center = torch.tensor([[101,123-4],[155,123-4],[128,156-NOSE_H/2+16],[128,185]])
-                    eyel_p[i] = torch.nn.ConstantPad2d((int(center[0,0] - EYE_W / 2 - 1),
-
-
-
+                    center = self.center[i]  # x,y
+                else:  # if average_pos = True
+                    center = torch.tensor([[101, 123 - 4], [155, 123 - 4], [128, 156 - NOSE_H / 2 + 16], [128, 185]])
+                eyel_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[0, 0] + EYE_W / 2 - 1)),
+                                                    int(center[0, 1] - EYE_H / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[0, 1] + EYE_H / 2 - 1))), -1)(eyel[i])
+                eyer_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[1, 0] + EYE_W / 2 - 1)),
+                                                    int(center[1, 1] - EYE_H / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[1, 1] + EYE_H / 2 - 1))), -1)(eyer[i])
+                nose_p[i] = torch.nn.ConstantPad2d((int(center[2, 0] - NOSE_W / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[2, 0] + NOSE_W / 2 - 1)),
+                                                    int(center[2, 1] - NOSE_H / 2 - 1),
+                                                    int(IMAGE_SIZE - (center[2, 1] + NOSE_H / 2 - 1))), -1)(nose[i])
+                mouth_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2 - 1),
+                                                     int(IMAGE_SIZE - (center[3, 0] + MOUTH_W / 2 - 1)),
+                                                     int(center[3, 1] - MOUTH_H / 2 - 1),
+                                                     int(IMAGE_SIZE - (center[3, 1] + MOUTH_H / 2 - 1))), -1)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
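# Editor's sketch (not in the original source): the mask helpers keep images in
# the [-1, 1] range of the tanh generators while masks live in {0, 1}.
# Rescaling to [0, 1], masking, then rescaling back gives, per pixel:
#
#     masked(A, m)            -> A where m == 1, -1 (black) where m == 0
#     addone_with_mask(A, m)  -> A where m == 1, +1 (white) where m == 0
#     add_with_mask(A, B, m)  -> A where m == 1, B  where m == 0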
@@ -328,13 +342,14 @@ class BaseModel():
            eye_nose = torch.min(eyes, nose_p)
            result = torch.min(eye_nose, mouth_p)
        return result
-
-    def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op
+
+    def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op=1, region_enm=0, cmaskel=None, cmasker=None,
+                      cmaskno=None, cmaskmo=None):
        if comb_op == 0:
            # use max pooling, pad black for eyes etc
            padvalue = -1
            hair = self.masked(hair, mask)
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
@@ -343,12 +358,12 @@ class BaseModel():
            # use min pooling, pad white for eyes etc
            padvalue = 1
            hair = self.addone_with_mask(hair, mask)
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
-        if region_enm in [0,1]:
+        if region_enm in [0, 1]:  # need to pad
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
@@ -357,17 +372,26 @@ class BaseModel():
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
-            bs,nc,_,_ = eyel.shape
-            eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
+            bs, nc, _, _ = eyel.shape
+            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
-                center = self.center[i]#x,y
-                eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,
-
-
+                center = self.center[i]  # x,y
+                eyel_p[i] = torch.nn.ConstantPad2d((center[0, 0] - EYE_W / 2, IMAGE_SIZE - (center[0, 0] + EYE_W / 2),
+                                                    center[0, 1] - EYE_H / 2, IMAGE_SIZE - (center[0, 1] + EYE_H / 2)),
+                                                   padvalue)(eyel[i])
+                eyer_p[i] = torch.nn.ConstantPad2d((center[1, 0] - EYE_W / 2, IMAGE_SIZE - (center[1, 0] + EYE_W / 2),
+                                                    center[1, 1] - EYE_H / 2, IMAGE_SIZE - (center[1, 1] + EYE_H / 2)),
+                                                   padvalue)(eyer[i])
+                nose_p[i] = torch.nn.ConstantPad2d((center[2, 0] - NOSE_W / 2, IMAGE_SIZE - (center[2, 0] + NOSE_W / 2),
+                                                    center[2, 1] - NOSE_H / 2,
+                                                    IMAGE_SIZE - (center[2, 1] + NOSE_H / 2)), padvalue)(nose[i])
+                mouth_p[i] = torch.nn.ConstantPad2d((center[3, 0] - MOUTH_W / 2,
+                                                     IMAGE_SIZE - (center[3, 0] + MOUTH_W / 2),
+                                                     center[3, 1] - MOUTH_H / 2,
+                                                     IMAGE_SIZE - (center[3, 1] + MOUTH_H / 2)), padvalue)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
@@ -378,22 +402,23 @@ class BaseModel():
            eyes = torch.max(eyel_p, eyer_p)
            eye_nose = torch.max(eyes, nose_p)
            eye_nose_mouth = torch.max(eye_nose, mouth_p)
-            result = torch.max(hair,eye_nose_mouth)
+            result = torch.max(hair, eye_nose_mouth)
        else:
            # use min pooling
            eyes = torch.min(eyel_p, eyer_p)
            eye_nose = torch.min(eyes, nose_p)
            eye_nose_mouth = torch.min(eye_nose, mouth_p)
-            result = torch.min(hair,eye_nose_mouth)
+            result = torch.min(hair, eye_nose_mouth)
        return result
-
-    def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op
+
+    def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op=1, region_enm=0, cmaskel=None,
+                         cmasker=None, cmaskno=None, cmaskmo=None):
        if comb_op == 0:
            # use max pooling, pad black for eyes etc
            padvalue = -1
            hair = self.masked(hair, maskh)
            bg = self.masked(bg, maskb)
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = eyel * cmaskel
                eyer = eyer * cmasker
                nose = nose * cmaskno
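# Editor's sketch (not in the original source): with comb_op == 1 the parts are
# composed by min pooling. Each part is padded to full size with white (+1), so
# the element-wise minimum keeps the darkest stroke at every pixel and the
# padding never wins:
#
#     canvas = torch.min(torch.min(eyel_p, eyer_p), torch.min(nose_p, mouth_p))
#     canvas = torch.min(canvas, hair)  # hair/bg already masked to white outside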
@@ -403,12 +428,12 @@ class BaseModel():
            padvalue = 1
            hair = self.addone_with_mask(hair, maskh)
            bg = self.addone_with_mask(bg, maskb)
-            if region_enm in [1,2]:
+            if region_enm in [1, 2]:
                eyel = self.addone_with_mask(eyel, cmaskel)
                eyer = self.addone_with_mask(eyer, cmasker)
                nose = self.addone_with_mask(nose, cmaskno)
                mouth = self.addone_with_mask(mouth, cmaskmo)
-        if region_enm in [0,1]:
+        if region_enm in [0, 1]:  # need to pad to full size
            IMAGE_SIZE = self.opt.fineSize
            ratio = IMAGE_SIZE / 256
            EYE_W = self.opt.EYE_W * ratio
@@ -417,17 +442,29 @@ class BaseModel():
            NOSE_H = self.opt.NOSE_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
-            bs,nc,_,_ = eyel.shape
-            eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
-            mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device)
+            bs, nc, _, _ = eyel.shape
+            eyel_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            eyer_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            nose_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
+            mouth_p = torch.ones((bs, nc, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
            for i in range(bs):
-                center = self.center[i]#x,y
-                eyel_p[i] = torch.nn.ConstantPad2d((
-
-
+                center = self.center[i]  # x,y
+                eyel_p[i] = torch.nn.ConstantPad2d((int(center[0, 0] - EYE_W / 2),
+                                                    IMAGE_SIZE - int(center[0, 0] + EYE_W / 2),
+                                                    int(center[0, 1] - EYE_H / 2),
+                                                    IMAGE_SIZE - int(center[0, 1] + EYE_H / 2)), padvalue)(eyel[i])
+                eyer_p[i] = torch.nn.ConstantPad2d((int(center[1, 0] - EYE_W / 2),
+                                                    IMAGE_SIZE - int(center[1, 0] + EYE_W / 2),
+                                                    int(center[1, 1] - EYE_H / 2),
+                                                    IMAGE_SIZE - int(center[1, 1] + EYE_H / 2)), padvalue)(eyer[i])
+                nose_p[i] = torch.nn.ConstantPad2d((int(center[2, 0] - NOSE_W / 2),
+                                                    IMAGE_SIZE - int(center[2, 0] + NOSE_W / 2),
+                                                    int(center[2, 1] - NOSE_H / 2),
+                                                    IMAGE_SIZE - int(center[2, 1] + NOSE_H / 2)), padvalue)(nose[i])
+                mouth_p[i] = torch.nn.ConstantPad2d((int(center[3, 0] - MOUTH_W / 2),
+                                                     IMAGE_SIZE - int(center[3, 0] + MOUTH_W / 2),
+                                                     int(center[3, 1] - MOUTH_H / 2),
+                                                     IMAGE_SIZE - int(center[3, 1] + MOUTH_H / 2)), padvalue)(mouth[i])
        elif region_enm in [2]:
            eyel_p = eyel
            eyer_p = eyer
@@ -437,17 +474,17 @@ class BaseModel():
            eyes = torch.max(eyel_p, eyer_p)
            eye_nose = torch.max(eyes, nose_p)
            eye_nose_mouth = torch.max(eye_nose, mouth_p)
-            eye_nose_mouth_hair = torch.max(hair,eye_nose_mouth)
-            result = torch.max(bg,eye_nose_mouth_hair)
+            eye_nose_mouth_hair = torch.max(hair, eye_nose_mouth)
+            result = torch.max(bg, eye_nose_mouth_hair)
        else:
            eyes = torch.min(eyel_p, eyer_p)
            eye_nose = torch.min(eyes, nose_p)
            eye_nose_mouth = torch.min(eye_nose, mouth_p)
-            eye_nose_mouth_hair = torch.min(hair,eye_nose_mouth)
-            result = torch.min(bg,eye_nose_mouth_hair)
+            eye_nose_mouth_hair = torch.min(hair, eye_nose_mouth)
+            result = torch.min(bg, eye_nose_mouth_hair)
        return result
-
-    def partCombiner3(self, face, hair, maskf, maskh, comb_op
+
+    def partCombiner3(self, face, hair, maskf, maskh, comb_op=1):
        if comb_op == 0:
            # use max pooling, pad black etc
            padvalue = -1
@@ -459,27 +496,25 @@ class BaseModel():
        face = self.addone_with_mask(face, maskf)
        hair = self.addone_with_mask(hair, maskh)
        if comb_op == 0:
-            result = torch.max(face,hair)
+            result = torch.max(face, hair)
        else:
-            result = torch.min(face,hair)
+            result = torch.min(face, hair)
        return result

-
    def tocv2(ts):
-        img = (ts.numpy()/2+0.5)*255
+        img = (ts.numpy() / 2 + 0.5) * 255
        img = img.astype('uint8')
-        img = np.transpose(img,(1,2,0))
-        img = img[
+        img = np.transpose(img, (1, 2, 0))
+        img = img[:, :, ::-1]  # rgb->bgr
        return img
-
+
    def totor(img):
-        img = img[
+        img = img[:, :, ::-1]
        tor = transforms.ToTensor()(img)
        tor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor)
        return tor
-
-
-    def ContinuityForTest(self, real = 0):
+
+    def ContinuityForTest(self, real=0):
        # Patch-based
        self.get_patches()
        self.outputs = self.netRegressor(self.fake_B_patches)
@@ -494,16 +529,17 @@ class BaseModel():
        self.get_patches_real()
        self.outputs2 = self.netRegressor(self.real_B_patches)
        line_continuity2 = torch.mean(self.outputs2)
-        file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
+        file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch),
+                                 'continuity-r.txt')
        message = '%s %.04f' % (self.image_paths[0], line_continuity2)
        with open(file_name, 'a+') as c_file:
            c_file.write(message)
            c_file.write('\n')
-
-    def getLocalParts(self,fakeAB):
-        bs,nc,_,_ = fakeAB.shape
+
+    def getLocalParts(self, fakeAB):
+        bs, nc, _, _ = fakeAB.shape  # dtype torch.float32
        ncr = int(nc / self.opt.output_nc)
-        if self.opt.region_enm in [0,1]:
+        if self.opt.region_enm in [0, 1]:
            ratio = self.opt.fineSize / 256
            EYE_H = self.opt.EYE_H * ratio
            EYE_W = self.opt.EYE_W * ratio
@@ -511,28 +547,32 @@ class BaseModel():
            NOSE_W = self.opt.NOSE_W * ratio
            MOUTH_H = self.opt.MOUTH_H * ratio
            MOUTH_W = self.opt.MOUTH_W * ratio
-            eyel = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device)
-            eyer = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device)
-            nose = torch.ones((bs,nc,int(NOSE_H),int(NOSE_W))).to(self.device)
-            mouth = torch.ones((bs,nc,int(MOUTH_H),int(MOUTH_W))).to(self.device)
+            eyel = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device)
+            eyer = torch.ones((bs, nc, int(EYE_H), int(EYE_W))).to(self.device)
+            nose = torch.ones((bs, nc, int(NOSE_H), int(NOSE_W))).to(self.device)
+            mouth = torch.ones((bs, nc, int(MOUTH_H), int(MOUTH_W))).to(self.device)
            for i in range(bs):
                center = self.center[i]
-                eyel[i] = fakeAB[i
-
-
+                eyel[i] = fakeAB[i, :, center[0, 1] - EYE_H / 2:center[0, 1] + EYE_H / 2,
+                          center[0, 0] - EYE_W / 2:center[0, 0] + EYE_W / 2]
+                eyer[i] = fakeAB[i, :, center[1, 1] - EYE_H / 2:center[1, 1] + EYE_H / 2,
+                          center[1, 0] - EYE_W / 2:center[1, 0] + EYE_W / 2]
+                nose[i] = fakeAB[i, :, center[2, 1] - NOSE_H / 2:center[2, 1] + NOSE_H / 2,
+                          center[2, 0] - NOSE_W / 2:center[2, 0] + NOSE_W / 2]
+                mouth[i] = fakeAB[i, :, center[3, 1] - MOUTH_H / 2:center[3, 1] + MOUTH_H / 2,
+                           center[3, 0] - MOUTH_W / 2:center[3, 0] + MOUTH_W / 2]
        elif self.opt.region_enm in [2]:
-            eyel = (fakeAB/2+0.5) * self.cmaskel.repeat(1,ncr,1,1) * 2 - 1
-            eyer = (fakeAB/2+0.5) * self.cmasker.repeat(1,ncr,1,1) * 2 - 1
-            nose = (fakeAB/2+0.5) * self.cmask.repeat(1,ncr,1,1) * 2 - 1
-            mouth = (fakeAB/2+0.5) * self.cmaskmo.repeat(1,ncr,1,1) * 2 - 1
-            hair = (fakeAB/2+0.5) * self.mask.repeat(1,ncr,1,1) * self.mask2.repeat(1,ncr,1,1) * 2 - 1
-            bg = (fakeAB/2+0.5) * (torch.ones(fakeAB.shape).to(self.device)-self.mask2.repeat(1,ncr,1,1)) * 2 - 1
+            eyel = (fakeAB / 2 + 0.5) * self.cmaskel.repeat(1, ncr, 1, 1) * 2 - 1
+            eyer = (fakeAB / 2 + 0.5) * self.cmasker.repeat(1, ncr, 1, 1) * 2 - 1
+            nose = (fakeAB / 2 + 0.5) * self.cmask.repeat(1, ncr, 1, 1) * 2 - 1
+            mouth = (fakeAB / 2 + 0.5) * self.cmaskmo.repeat(1, ncr, 1, 1) * 2 - 1
+            hair = (fakeAB / 2 + 0.5) * self.mask.repeat(1, ncr, 1, 1) * self.mask2.repeat(1, ncr, 1, 1) * 2 - 1
+            bg = (fakeAB / 2 + 0.5) * (torch.ones(fakeAB.shape).to(self.device) - self.mask2.repeat(1, ncr, 1, 1)) * 2 - 1
        return eyel, eyer, nose, mouth, hair, bg
-
-    def getaddw(self,local_name):
+
+    def getaddw(self, local_name):
        addw = 1
-        if local_name in ['DLEyel','DLEyer','eyel','eyer','DLFace','face']:
+        if local_name in ['DLEyel', 'DLEyer', 'eyel', 'eyer', 'DLFace', 'face']:
            addw = self.opt.addw_eye
        elif local_name in ['DLNose', 'nose']:
            addw = self.opt.addw_nose
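# Editor's sketch (not in the original source): getLocalParts crops with the
# same centers and integer part extents that partCombiner2_bg pads with, so a
# crop followed by the corresponding pad restores the part at its original
# position (the surrounding pixels become padvalue). For one part of height h
# and width w around (cx, cy):
#
#     part = fakeAB[i, :, cy - h // 2:cy + h // 2, cx - w // 2:cx + w // 2]
#     assert part.shape[-2:] == torch.Size([h, w])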