WensongSong commited on
Commit
20d069f
·
verified ·
1 Parent(s): a6c1364

Upload 28 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,15 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/ref_image/1.png filter=lfs diff=lfs merge=lfs -text
37
+ examples/ref_image/2.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/ref_image/3.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/ref_image/4.png filter=lfs diff=lfs merge=lfs -text
40
+ examples/result/1.png filter=lfs diff=lfs merge=lfs -text
41
+ examples/result/2.png filter=lfs diff=lfs merge=lfs -text
42
+ examples/result/3.png filter=lfs diff=lfs merge=lfs -text
43
+ examples/result/4.png filter=lfs diff=lfs merge=lfs -text
44
+ examples/result/5.png filter=lfs diff=lfs merge=lfs -text
45
+ examples/source_image/1.png filter=lfs diff=lfs merge=lfs -text
46
+ examples/source_image/2.png filter=lfs diff=lfs merge=lfs -text
47
+ examples/source_image/3.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from PIL import Image, ImageFilter, ImageDraw
8
+ from huggingface_hub import snapshot_download
9
+ from diffusers import FluxFillPipeline, FluxPriorReduxPipeline
10
+ import math
11
+ from utils.utils import get_bbox_from_mask, expand_bbox, pad_to_square, box2squre, crop_back, expand_image_mask
12
+
13
+
14
+ hf_token = os.getenv("HF_TOKEN")
15
+
16
+ snapshot_download(repo_id="black-forest-labs/FLUX.1-Fill-dev", local_dir="./FLUX.1-Fill-dev", token=hf_token)
17
+ snapshot_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", local_dir="./FLUX.1-Redux-dev", token=hf_token)
18
+ snapshot_download(repo_id="WensongSong/Insert-Anything", local_dir="./insertanything_model", token=hf_token)
19
+
20
+
21
+ dtype = torch.bfloat16
22
+ size = (768, 768)
23
+
24
+ pipe = FluxFillPipeline.from_pretrained(
25
+ "./FLUX.1-Fill-dev",
26
+ torch_dtype=dtype
27
+ ).to("cuda")
28
+
29
+ pipe.load_lora_weights(
30
+ "./insertanything_model/20250321-082022_steps5000_pytorch_lora_weights.safetensors"
31
+ )
32
+
33
+
34
+ redux = FluxPriorReduxPipeline.from_pretrained("./FLUX.1-Redux-dev").to(dtype=dtype).to("cuda")
35
+
36
+
37
+
38
+ ### example #####
39
+ ref_dir='./examples/ref_image'
40
+ ref_mask_dir='./examples/ref_mask'
41
+ image_dir='./examples/source_image'
42
+ image_mask_dir='./examples/source_mask'
43
+
44
+ ref_list=[os.path.join(ref_dir,file) for file in os.listdir(ref_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
45
+ ref_list.sort()
46
+
47
+ ref_mask_list=[os.path.join(ref_mask_dir,file) for file in os.listdir(ref_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
48
+ ref_mask_list.sort()
49
+
50
+ image_list=[os.path.join(image_dir,file) for file in os.listdir(image_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file ]
51
+ image_list.sort()
52
+
53
+ image_mask_list=[os.path.join(image_mask_dir,file) for file in os.listdir(image_mask_dir) if '.jpg' in file or '.png' in file or '.jpeg' in file]
54
+ image_mask_list.sort()
55
+ ### example #####
56
+
57
+
58
+
59
+
60
+ def run_local(base_image, base_mask, reference_image, ref_mask, seed, base_mask_option, ref_mask_option):
61
+
62
+ if base_mask_option == "Draw Mask":
63
+ tar_image = base_image["image"]
64
+ tar_mask = base_image["mask"]
65
+ else:
66
+ tar_image = base_image["image"]
67
+ tar_mask = base_mask
68
+
69
+ if ref_mask_option == "Draw Mask":
70
+ ref_image = reference_image["image"]
71
+ ref_mask = reference_image["mask"]
72
+ else:
73
+ ref_image = reference_image["image"]
74
+ ref_mask = ref_mask
75
+
76
+
77
+ tar_image = tar_image.convert("RGB")
78
+ tar_mask = tar_mask.convert("L")
79
+ ref_image = ref_image.convert("RGB")
80
+ ref_mask = ref_mask.convert("L")
81
+
82
+ tar_image = np.asarray(tar_image)
83
+ tar_mask = np.asarray(tar_mask)
84
+ tar_mask = np.where(tar_mask > 128, 1, 0).astype(np.uint8)
85
+
86
+ ref_image = np.asarray(ref_image)
87
+ ref_mask = np.asarray(ref_mask)
88
+ ref_mask = np.where(ref_mask > 128, 1, 0).astype(np.uint8)
89
+
90
+
91
+ ref_box_yyxx = get_bbox_from_mask(ref_mask)
92
+ ref_mask_3 = np.stack([ref_mask,ref_mask,ref_mask],-1)
93
+ masked_ref_image = ref_image * ref_mask_3 + np.ones_like(ref_image) * 255 * (1-ref_mask_3)
94
+ y1,y2,x1,x2 = ref_box_yyxx
95
+ masked_ref_image = masked_ref_image[y1:y2,x1:x2,:]
96
+ ref_mask = ref_mask[y1:y2,x1:x2]
97
+ ratio = 1.3
98
+ masked_ref_image, ref_mask = expand_image_mask(masked_ref_image, ref_mask, ratio=ratio)
99
+
100
+
101
+ masked_ref_image = pad_to_square(masked_ref_image, pad_value = 255, random = False)
102
+
103
+ kernel = np.ones((7, 7), np.uint8)
104
+ iterations = 2
105
+ tar_mask = cv2.dilate(tar_mask, kernel, iterations=iterations)
106
+
107
+ # zome in
108
+ tar_box_yyxx = get_bbox_from_mask(tar_mask)
109
+ tar_box_yyxx = expand_bbox(tar_mask, tar_box_yyxx, ratio=1.2)
110
+
111
+ tar_box_yyxx_crop = expand_bbox(tar_image, tar_box_yyxx, ratio=2) #1.2 1.6
112
+ tar_box_yyxx_crop = box2squre(tar_image, tar_box_yyxx_crop) # crop box
113
+ y1,y2,x1,x2 = tar_box_yyxx_crop
114
+
115
+
116
+ old_tar_image = tar_image.copy()
117
+ tar_image = tar_image[y1:y2,x1:x2,:]
118
+ tar_mask = tar_mask[y1:y2,x1:x2]
119
+
120
+ H1, W1 = tar_image.shape[0], tar_image.shape[1]
121
+ # zome in
122
+
123
+
124
+ tar_mask = pad_to_square(tar_mask, pad_value=0)
125
+ tar_mask = cv2.resize(tar_mask, size)
126
+
127
+ masked_ref_image = cv2.resize(masked_ref_image.astype(np.uint8), size).astype(np.uint8)
128
+ pipe_prior_output = redux(Image.fromarray(masked_ref_image))
129
+
130
+
131
+ tar_image = pad_to_square(tar_image, pad_value=255)
132
+
133
+ H2, W2 = tar_image.shape[0], tar_image.shape[1]
134
+
135
+ tar_image = cv2.resize(tar_image, size)
136
+ diptych_ref_tar = np.concatenate([masked_ref_image, tar_image], axis=1)
137
+
138
+
139
+ tar_mask = np.stack([tar_mask,tar_mask,tar_mask],-1)
140
+ mask_black = np.ones_like(tar_image) * 0
141
+ mask_diptych = np.concatenate([mask_black, tar_mask], axis=1)
142
+
143
+
144
+ diptych_ref_tar = Image.fromarray(diptych_ref_tar)
145
+ mask_diptych[mask_diptych == 1] = 255
146
+ mask_diptych = Image.fromarray(mask_diptych)
147
+
148
+
149
+
150
+ generator = torch.Generator("cuda").manual_seed(seed)
151
+ edited_image = pipe(
152
+ image=diptych_ref_tar,
153
+ mask_image=mask_diptych,
154
+ height=mask_diptych.size[1],
155
+ width=mask_diptych.size[0],
156
+ max_sequence_length=512,
157
+ generator=generator,
158
+ **pipe_prior_output,
159
+ ).images[0]
160
+
161
+
162
+
163
+ width, height = edited_image.size
164
+ left = width // 2
165
+ right = width
166
+ top = 0
167
+ bottom = height
168
+ edited_image = edited_image.crop((left, top, right, bottom))
169
+
170
+
171
+ edited_image = np.array(edited_image)
172
+ edited_image = crop_back(edited_image, old_tar_image, np.array([H1, W1, H2, W2]), np.array(tar_box_yyxx_crop))
173
+ edited_image = Image.fromarray(edited_image)
174
+
175
+
176
+ return [edited_image]
177
+
178
+ def update_ui(option):
179
+ if option == "Draw Mask":
180
+ return gr.update(visible=False), gr.update(visible=True)
181
+ else:
182
+ return gr.update(visible=True), gr.update(visible=False)
183
+
184
+
185
+ with gr.Blocks() as demo:
186
+
187
+
188
+ gr.Markdown("#  Play with InsertAnything to Insert your Target Objects! ")
189
+ gr.Markdown("# Upload / Draw Images for the Background (up) and Reference Object (down)")
190
+ gr.Markdown("### Draw mask on the background or just upload the mask.")
191
+ gr.Markdown("### Only select one of these two methods. Don't forget to click the corresponding button!!")
192
+
193
+ with gr.Row():
194
+ with gr.Column():
195
+ with gr.Row():
196
+ base_image = gr.Image(label="Background Image", source="upload", tool="sketch", type="pil",
197
+ brush_color='#FFFFFF', mask_opacity=0.5)
198
+
199
+ base_mask = gr.Image(label="Background Mask", source="upload", type="pil")
200
+
201
+ with gr.Row():
202
+ base_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Background Mask Input Option", value="Upload with Mask")
203
+
204
+ with gr.Row():
205
+ ref_image = gr.Image(label="Reference Image", source="upload", tool="sketch", type="pil",
206
+ brush_color='#FFFFFF', mask_opacity=0.5)
207
+
208
+ ref_mask = gr.Image(label="Reference Mask", source="upload", type="pil")
209
+
210
+ with gr.Row():
211
+ ref_mask_option = gr.Radio(["Draw Mask", "Upload with Mask"], label="Reference Mask Input Option", value="Upload with Mask")
212
+
213
+ baseline_gallery = gr.Gallery(label='Output', show_label=True, elem_id="gallery", height=512, columns=1)
214
+ with gr.Accordion("Advanced Option", open=True):
215
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=999999999, step=1, value=666)
216
+ gr.Markdown("### Guidelines")
217
+ gr.Markdown(" Users can try using different seeds. For example, seeds like 42 and 123456 may produce different effects.")
218
+
219
+ run_local_button = gr.Button(value="Run")
220
+
221
+
222
+ # #### example #####
223
+ num_examples = len(image_list)
224
+ for i in range(num_examples):
225
+ with gr.Row():
226
+ if i == 0:
227
+ gr.Examples([image_list[i]], inputs=[base_image], label="Examples - Background Image", examples_per_page=1)
228
+ gr.Examples([image_mask_list[i]], inputs=[base_mask], label="Examples - Background Mask", examples_per_page=1)
229
+ gr.Examples([ref_list[i]], inputs=[ref_image], label="Examples - Reference Object", examples_per_page=1)
230
+ gr.Examples([ref_mask_list[i]], inputs=[ref_mask], label="Examples - Reference Mask", examples_per_page=1)
231
+ else:
232
+ gr.Examples([image_list[i]], inputs=[base_image], examples_per_page=1, label="")
233
+ gr.Examples([image_mask_list[i]], inputs=[base_mask], examples_per_page=1, label="")
234
+ gr.Examples([ref_list[i]], inputs=[ref_image], examples_per_page=1, label="")
235
+ gr.Examples([ref_mask_list[i]], inputs=[ref_mask], examples_per_page=1, label="")
236
+ if i < num_examples - 1:
237
+ with gr.Row():
238
+ gr.HTML("<hr>")
239
+ # #### example #####
240
+
241
+ run_local_button.click(fn=run_local,
242
+ inputs=[base_image, base_mask, ref_image, ref_mask, seed, base_mask_option, ref_mask_option],
243
+ outputs=[baseline_gallery]
244
+ )
245
+ demo.launch()
examples/ref_image/1.png ADDED

Git LFS Details

  • SHA256: 30bd779c676985f416d306623ac23f648de259245dbff135d5c3d95b8faf9bde
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
examples/ref_image/2.png ADDED

Git LFS Details

  • SHA256: 15417dd835e4b1349c5b34b0cfffc08b6cd930ad785aa8cc29bf56e8a6938588
  • Pointer size: 131 Bytes
  • Size of remote file: 131 kB
examples/ref_image/3.png ADDED

Git LFS Details

  • SHA256: b5cc18fdd8172cdad23f36a85b1e16d0f0fc8dc4784462cddd8bf50b05b73eb8
  • Pointer size: 132 Bytes
  • Size of remote file: 4.47 MB
examples/ref_image/4.png ADDED

Git LFS Details

  • SHA256: d368729e8d7368e376fe7b675298475262352b9f6bb2b8adcf50b578f766e409
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
examples/ref_image/5.png ADDED
examples/ref_mask/1.png ADDED
examples/ref_mask/2.png ADDED
examples/ref_mask/3.png ADDED
examples/ref_mask/4.png ADDED
examples/ref_mask/5.png ADDED
examples/result/1.png ADDED

Git LFS Details

  • SHA256: aec0a16fe46f3d634b3557a4233a0bb7fd49293468ee31292e968a35bd627e13
  • Pointer size: 131 Bytes
  • Size of remote file: 987 kB
examples/result/2.png ADDED

Git LFS Details

  • SHA256: eaa39e269a7645a60d6ec3454edea3d3f740636882193836260209f41f638a21
  • Pointer size: 132 Bytes
  • Size of remote file: 5.44 MB
examples/result/3.png ADDED

Git LFS Details

  • SHA256: dd132c579e941419f1782baf8c444bd7938843e9b0bafd7f2e76f8ef3a40f995
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
examples/result/4.png ADDED

Git LFS Details

  • SHA256: e145c8487e3a1c07db15187e564c5be94e4fd3701b910fa45bc1ecdc526b671a
  • Pointer size: 131 Bytes
  • Size of remote file: 787 kB
examples/result/5.png ADDED

Git LFS Details

  • SHA256: adf34a05a10fcc73b6363c4884faf2670f489efad2c35879b2f4374daeb26fcd
  • Pointer size: 131 Bytes
  • Size of remote file: 825 kB
examples/source_image/1.png ADDED

Git LFS Details

  • SHA256: 7f1a369348e51385401bfd74b40c70eb36eeef9fb5f89cd986d6d3c2aa73b778
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
examples/source_image/2.png ADDED

Git LFS Details

  • SHA256: 3118d06bf22fb9d8db8afe33f549e69217909d0108599dc59f9d17640193853d
  • Pointer size: 132 Bytes
  • Size of remote file: 6.13 MB
examples/source_image/3.png ADDED

Git LFS Details

  • SHA256: 9ca74fe50d940e2fa84ffed72f8ce5f382821d1cbd2bfde58e2813171dc04d4e
  • Pointer size: 131 Bytes
  • Size of remote file: 333 kB
examples/source_image/4.png ADDED
examples/source_image/5.png ADDED
examples/source_mask/1.png ADDED
examples/source_mask/2.png ADDED
examples/source_mask/3.png ADDED
examples/source_mask/4.png ADDED
examples/source_mask/5.png ADDED
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.5.1
2
+ torchvision==0.20.1
3
+ diffusers==0.32.2
4
+ transformers==4.50.3
5
+ peft==0.15.1
6
+ opencv-python
7
+ protobuf
8
+ sentencepiece
9
+ gradio==3.39.0
10
+ bezier
11
+ lightning==2.5.1
12
+ datasets
13
+ prodigyopt
14
+ einops
15
+ scipy
utils/utils.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+ def f(r, T=0.6, beta=0.1):
6
+ return np.where(r < T, beta + (1 - beta) / T * r, 1)
7
+
8
+ # Get the bounding box of the mask
9
+ def get_bbox_from_mask(mask):
10
+ h,w = mask.shape[0],mask.shape[1]
11
+
12
+ if mask.sum() < 10:
13
+ return 0,h,0,w
14
+ rows = np.any(mask,axis=1)
15
+ cols = np.any(mask,axis=0)
16
+ y1,y2 = np.where(rows)[0][[0,-1]]
17
+ x1,x2 = np.where(cols)[0][[0,-1]]
18
+ return (y1,y2,x1,x2)
19
+
20
+ # Expand the bounding box
21
+ def expand_bbox(mask, yyxx, ratio, min_crop=0):
22
+ y1,y2,x1,x2 = yyxx
23
+ H,W = mask.shape[0], mask.shape[1]
24
+
25
+ yyxx_area = (y2-y1+1) * (x2-x1+1)
26
+ r1 = yyxx_area / (H * W)
27
+ r2 = f(r1)
28
+ ratio = math.sqrt(r2 / r1)
29
+
30
+ xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
31
+ h = ratio * (y2-y1+1)
32
+ w = ratio * (x2-x1+1)
33
+ h = max(h,min_crop)
34
+ w = max(w,min_crop)
35
+
36
+ x1 = int(xc - w * 0.5)
37
+ x2 = int(xc + w * 0.5)
38
+ y1 = int(yc - h * 0.5)
39
+ y2 = int(yc + h * 0.5)
40
+
41
+ x1 = max(0,x1)
42
+ x2 = min(W,x2)
43
+ y1 = max(0,y1)
44
+ y2 = min(H,y2)
45
+ return (y1,y2,x1,x2)
46
+
47
+ # Pad the image to a square shape
48
+ def pad_to_square(image, pad_value = 255, random = False):
49
+ H,W = image.shape[0], image.shape[1]
50
+ if H == W:
51
+ return image
52
+
53
+ padd = abs(H - W)
54
+ if random:
55
+ padd_1 = int(np.random.randint(0,padd))
56
+ else:
57
+ padd_1 = int(padd / 2)
58
+ padd_2 = padd - padd_1
59
+
60
+ if len(image.shape) == 2:
61
+ if H > W:
62
+ pad_param = ((0, 0), (padd_1, padd_2))
63
+ else:
64
+ pad_param = ((padd_1, padd_2), (0, 0))
65
+ elif len(image.shape) == 3:
66
+ if H > W:
67
+ pad_param = ((0, 0), (padd_1, padd_2), (0, 0))
68
+ else:
69
+ pad_param = ((padd_1, padd_2), (0, 0), (0, 0))
70
+
71
+ image = np.pad(image, pad_param, 'constant', constant_values=pad_value)
72
+
73
+ return image
74
+
75
+ # Expand the image and mask
76
+ def expand_image_mask(image, mask, ratio=1.4):
77
+ h,w = image.shape[0], image.shape[1]
78
+ H,W = int(h * ratio), int(w * ratio)
79
+ h1 = int((H - h) // 2)
80
+ h2 = H - h - h1
81
+ w1 = int((W -w) // 2)
82
+ w2 = W -w - w1
83
+
84
+ pad_param_image = ((h1,h2),(w1,w2),(0,0))
85
+ pad_param_mask = ((h1,h2),(w1,w2))
86
+ image = np.pad(image, pad_param_image, 'constant', constant_values=255)
87
+ mask = np.pad(mask, pad_param_mask, 'constant', constant_values=0)
88
+ return image, mask
89
+
90
+ # Convert the bounding box to a square shape
91
+ def box2squre(image, box):
92
+ H,W = image.shape[0], image.shape[1]
93
+ y1,y2,x1,x2 = box
94
+ cx = (x1 + x2) // 2
95
+ cy = (y1 + y2) // 2
96
+ h,w = y2-y1, x2-x1
97
+
98
+ if h >= w:
99
+ x1 = cx - h//2
100
+ x2 = cx + h//2
101
+ else:
102
+ y1 = cy - w//2
103
+ y2 = cy + w//2
104
+ x1 = max(0,x1)
105
+ x2 = min(W,x2)
106
+ y1 = max(0,y1)
107
+ y2 = min(H,y2)
108
+ return (y1,y2,x1,x2)
109
+
110
+ # Crop the predicted image back to the original image
111
+ def crop_back( pred, tar_image, extra_sizes, tar_box_yyxx_crop):
112
+ H1, W1, H2, W2 = extra_sizes
113
+ y1,y2,x1,x2 = tar_box_yyxx_crop
114
+ pred = cv2.resize(pred, (W2, H2))
115
+ m = 2 # maigin_pixel
116
+
117
+ if W1 == H1:
118
+ if m != 0:
119
+ tar_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
120
+ else:
121
+ tar_image[y1 :y2, x1:x2, :] = pred[:, :]
122
+ return tar_image
123
+
124
+ if W1 < W2:
125
+ pad1 = int((W2 - W1) / 2)
126
+ pad2 = W2 - W1 - pad1
127
+ pred = pred[:,pad1: -pad2, :]
128
+ else:
129
+ pad1 = int((H2 - H1) / 2)
130
+ pad2 = H2 - H1 - pad1
131
+ pred = pred[pad1: -pad2, :, :]
132
+
133
+ gen_image = tar_image.copy()
134
+ if m != 0:
135
+ gen_image[y1+m :y2-m, x1+m:x2-m, :] = pred[m:-m, m:-m]
136
+ else:
137
+ gen_image[y1 :y2, x1:x2, :] = pred[:, :]
138
+
139
+ return gen_image
140
+