svjack committed
Commit f85df78 · verified · 1 Parent(s): 9072a88

Create app_func.py

Files changed (1)
  1. app_func.py +458 -0
app_func.py ADDED
@@ -0,0 +1,458 @@
+ #import spaces
+ import contextlib
+ import gc
+ import json
+ import logging
+ import math
+ import os
+ import random
+ import shutil
+ import sys
+ import time
+ import itertools
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+
+ import accelerate
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import ProjectConfiguration, set_seed
+
+ from datasets import load_dataset
+ from huggingface_hub import create_repo, upload_folder
+ from packaging import version
+ from safetensors.torch import load_model
+ from peft import LoraConfig
+ import gradio as gr
+ import pandas as pd
+
+ import transformers
+ from transformers import (
+     AutoTokenizer,
+     PretrainedConfig,
+     CLIPVisionModelWithProjection,
+     CLIPImageProcessor,
+     CLIPProcessor,
+ )
+
+ import diffusers
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+     ColorGuiderPixArtModel,
+     ColorGuiderSDModel,
+     UNet2DConditionModel,
+     PixArtTransformer2DModel,
+     ColorFlowPixArtAlphaPipeline,
+     ColorFlowSDPipeline,
+     UniPCMultistepScheduler,
+ )
+ from colorflow_utils.utils import *
+
+ sys.path.append('./BidirectionalTranslation')
+ from options.test_options import TestOptions
+ from models import create_model
+ from util import util
+
+ from huggingface_hub import snapshot_download
+
+
+ article = r"""
+ If ColorFlow is helpful, please help to ⭐ the <a href='https://github.com/TencentARC/ColorFlow' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/TencentARC/ColorFlow)](https://github.com/TencentARC/ColorFlow)
+ ---
+
+ 📧 **Contact**
+ <br>
+ If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
+
+ 📝 **Citation**
+ <br>
+ If our work is useful for your research, please consider citing:
+ ```bibtex
+ @misc{zhuang2024colorflow,
+       title={ColorFlow: Retrieval-Augmented Image Sequence Colorization},
+       author={Junhao Zhuang and Xuan Ju and Zhaoyang Zhang and Yong Liu and Shiyi Zhang and Chun Yuan and Ying Shan},
+       year={2024},
+       eprint={2412.11815},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV},
+       url={https://arxiv.org/abs/2412.11815},
+ }
+ ```
+ """
+
+ model_global_path = snapshot_download(repo_id="TencentARC/ColorFlow", cache_dir='./colorflow/', repo_type="model")
+ print(model_global_path)
+
+
+ transform = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+ ])
+ weight_dtype = torch.float16
+
+ # line model
+ line_model_path = model_global_path + '/LE/erika.pth'
+ line_model = res_skip()
+ line_model.load_state_dict(torch.load(line_model_path))
+ line_model.eval()
+ line_model.cuda()
+
+ # screen model
+ global opt
+
+ opt = TestOptions().parse(model_global_path)
+ ScreenModel = create_model(opt, model_global_path)
+ ScreenModel.setup(opt)
+ ScreenModel.eval()
+
+ # CLIP image encoder used for reference-patch retrieval
+ image_processor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(model_global_path + '/image_encoder/').to('cuda')
+
+
+ examples = [
+     [
+         "./assets/example_6/input.jpg",
+         ["./assets/example_6/ref1.jpg", "./assets/example_6/ref2.jpg", "./assets/example_6/ref3.jpg"],
+         "GrayImage(ScreenStyle)",
+         "512x800",
+         0,
+         10
+     ],
+     [
+         "原神漫画2019101113203050769.jpg",
+         ["凯亚(20).png", "安柏 (20).png"],
+         "GrayImage(ScreenStyle)",
+         "512x800",
+         0,
+         10
+     ],
+     [
+         "./assets/example_5/input.png",
+         ["./assets/example_5/ref1.png", "./assets/example_5/ref2.png", "./assets/example_5/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_4/input.jpg",
+         ["./assets/example_4/ref1.jpg", "./assets/example_4/ref2.jpg", "./assets/example_4/ref3.jpg"],
+         "GrayImage(ScreenStyle)",
+         "640x640",
+         0,
+         10
+     ],
+     [
+         "./assets/example_3/input.png",
+         ["./assets/example_3/ref1.png", "./assets/example_3/ref2.png", "./assets/example_3/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_2/input.png",
+         ["./assets/example_2/ref1.png", "./assets/example_2/ref2.png", "./assets/example_2/ref3.png"],
+         "GrayImage(ScreenStyle)",
+         "800x512",
+         0,
+         10
+     ],
+     [
+         "./assets/example_1/input.jpg",
+         ["./assets/example_1/ref1.jpg", "./assets/example_1/ref2.jpg", "./assets/example_1/ref3.jpg"],
+         "Sketch",
+         "640x640",
+         1,
+         10
+     ],
+     [
+         "./assets/example_0/input.jpg",
+         ["./assets/example_0/ref1.jpg"],
+         "Sketch",
+         "640x640",
+         1,
+         10
+     ],
+ ]
+
+ global pipeline
+ global MultiResNetModel
+
+ #@spaces.GPU
+ def load_ckpt(input_style):
+     global pipeline
+     global MultiResNetModel
+     if input_style == "Sketch":
+         ckpt_path = model_global_path + '/sketch/'
+         rank = 128
+         pretrained_model_name_or_path = 'PixArt-alpha/PixArt-XL-2-1024-MS'
+         transformer = PixArtTransformer2DModel.from_pretrained(
+             pretrained_model_name_or_path, subfolder="transformer", revision=None, variant=None
+         )
+         pixart_config = get_pixart_config()
+
+         ColorGuider = ColorGuiderPixArtModel.from_pretrained(ckpt_path)
+
+         transformer_lora_config = LoraConfig(
+             r=rank,
+             lora_alpha=rank,
+             init_lora_weights="gaussian",
+             target_modules=["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out", "ff.net.0.proj", "ff.net.2", "proj", "linear", "linear_1", "linear_2"]
+         )
+         transformer.add_adapter(transformer_lora_config)
+         ckpt_key_t = torch.load(ckpt_path + 'transformer_lora.bin', map_location='cpu')
+         transformer.load_state_dict(ckpt_key_t, strict=False)
+
+         transformer.to('cuda', dtype=weight_dtype)
+         ColorGuider.to('cuda', dtype=weight_dtype)
+
+         pipeline = ColorFlowPixArtAlphaPipeline.from_pretrained(
+             pretrained_model_name_or_path,
+             transformer=transformer,
+             colorguider=ColorGuider,
+             safety_checker=None,
+             revision=None,
+             variant=None,
+             torch_dtype=weight_dtype,
+         )
+         pipeline = pipeline.to("cuda")
+         block_out_channels = [128, 128, 256, 512, 512]
+
+         MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+         MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+         MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+     elif input_style == "GrayImage(ScreenStyle)":
+         ckpt_path = model_global_path + '/GraySD/'
+         rank = 64
+         pretrained_model_name_or_path = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
+         unet = UNet2DConditionModel.from_pretrained(
+             pretrained_model_name_or_path, subfolder="unet", revision=None, variant=None
+         )
+         ColorGuider = ColorGuiderSDModel.from_pretrained(ckpt_path)
+         ColorGuider.to('cuda', dtype=weight_dtype)
+         unet.to('cuda', dtype=weight_dtype)
+
+         pipeline = ColorFlowSDPipeline.from_pretrained(
+             pretrained_model_name_or_path,
+             unet=unet,
+             colorguider=ColorGuider,
+             safety_checker=None,
+             revision=None,
+             variant=None,
+             torch_dtype=weight_dtype,
+         )
+         pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
+         unet_lora_config = LoraConfig(
+             r=rank,
+             lora_alpha=rank,
+             init_lora_weights="gaussian",
+             target_modules=["to_k", "to_q", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"],
+         )
+         pipeline.unet.add_adapter(unet_lora_config)
+         pipeline.unet.load_state_dict(torch.load(ckpt_path + 'unet_lora.bin', map_location='cpu'), strict=False)
+         pipeline = pipeline.to("cuda")
+         block_out_channels = [128, 128, 256, 512, 512]
+
+         MultiResNetModel = MultiHiddenResNetModel(block_out_channels, len(block_out_channels))
+         MultiResNetModel.load_state_dict(torch.load(ckpt_path + 'MultiResNetModel.bin', map_location='cpu'), strict=False)
+         MultiResNetModel.to('cuda', dtype=weight_dtype)
+
+
+ # Run load_ckpt once per style at startup; cur_input_style is then cleared so the
+ # first request reloads whichever style it actually needs.
+ global cur_input_style
+ cur_input_style = "Sketch"
+ load_ckpt(cur_input_style)
+ cur_input_style = "GrayImage(ScreenStyle)"
+ load_ckpt(cur_input_style)
+ cur_input_style = None
+
+ #@spaces.GPU
+ def fix_random_seeds(seed):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+
+ def process_multi_images(files):
+     images = [Image.open(file.name) for file in files]
+     imgs = []
+     for i, img in enumerate(images):
+         imgs.append(img)
+     return imgs
+
+ #@spaces.GPU
+ def extract_lines(image):
+     src = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+
+     # Pad the grayscale input to a multiple of 16 before running the line model.
+     rows = int(np.ceil(src.shape[0] / 16)) * 16
+     cols = int(np.ceil(src.shape[1] / 16)) * 16
+
+     patch = np.ones((1, 1, rows, cols), dtype="float32")
+     patch[0, 0, 0:src.shape[0], 0:src.shape[1]] = src
+
+     tensor = torch.from_numpy(patch).cuda()
+
+     with torch.no_grad():
+         y = line_model(tensor)
+
+     yc = y.cpu().numpy()[0, 0, :, :]
+     yc[yc > 255] = 255
+     yc[yc < 0] = 0
+
+     outimg = yc[0:src.shape[0], 0:src.shape[1]]
+     outimg = outimg.astype(np.uint8)
+     outimg = Image.fromarray(outimg)
+     torch.cuda.empty_cache()
+     return outimg
+
+ #@spaces.GPU
+ def to_screen_image(input_image):
+     global opt
+     global ScreenModel
+     input_image = input_image.convert('RGB')
+     input_image = get_ScreenVAE_input(input_image, opt)
+     h = input_image['h']
+     w = input_image['w']
+     ScreenModel.set_input(input_image)
+     fake_B, fake_B2, SCR = ScreenModel.forward(AtoB=True)
+     images = fake_B2[:, :, :h, :w]
+     im = util.tensor2im(images)
+     image_pil = Image.fromarray(im)
+     torch.cuda.empty_cache()
+     return image_pil
+
+ #@spaces.GPU
+ def extract_line_image(query_image_, input_style, resolution):
+     if resolution == "640x640":
+         tar_width = 640
+         tar_height = 640
+     elif resolution == "512x800":
+         tar_width = 512
+         tar_height = 800
+     elif resolution == "800x512":
+         tar_width = 800
+         tar_height = 512
+     else:
+         gr.Info("Unsupported resolution")
+         raise ValueError("Unsupported resolution")
+
+     query_image = process_image(query_image_, int(tar_width*1.5), int(tar_height*1.5))
+     if input_style == "GrayImage(ScreenStyle)":
+         extracted_line = to_screen_image(query_image)
+         extracted_line = Image.blend(extracted_line.convert('L').convert('RGB'), query_image.convert('L').convert('RGB'), 0.5)
+         input_context = extracted_line
+     elif input_style == "Sketch":
+         query_image = query_image.convert('L').convert('RGB')
+         extracted_line = extract_lines(query_image)
+         extracted_line = extracted_line.convert('L').convert('RGB')
+         input_context = extracted_line
+     torch.cuda.empty_cache()
+     return input_context, extracted_line, input_context
+
+ #@spaces.GPU(duration=180)
+ def colorize_image(VAE_input, input_context, reference_images, resolution, seed, input_style, num_inference_steps):
+     if VAE_input is None or input_context is None:
+         gr.Info("Please preprocess the image first")
+         raise ValueError("Please preprocess the image first")
+     global cur_input_style
+     global pipeline
+     global MultiResNetModel
+     if input_style != cur_input_style:
+         gr.Info(f"Loading {input_style} model...")
+         load_ckpt(input_style)
+         cur_input_style = input_style
+         gr.Info(f"{input_style} model loaded")
+     reference_images = process_multi_images(reference_images)
+     fix_random_seeds(seed)
+     if resolution == "640x640":
+         tar_width = 640
+         tar_height = 640
+     elif resolution == "512x800":
+         tar_width = 512
+         tar_height = 800
+     elif resolution == "800x512":
+         tar_width = 800
+         tar_height = 512
+     else:
+         gr.Info("Unsupported resolution")
+         raise ValueError("Unsupported resolution")
+     validation_mask = Image.open('./assets/mask.png').convert('RGB').resize((tar_width*2, tar_height*2))
+     gr.Info("Image retrieval in progress...")
+     query_image_bw = process_image(input_context, int(tar_width), int(tar_height))
+     query_image = query_image_bw.convert('RGB')
+     query_image_vae = process_image(VAE_input, int(tar_width*1.5), int(tar_height*1.5))
+     reference_images = [process_image(ref_image, tar_width, tar_height) for ref_image in reference_images]
+     query_patches_pil = process_image_Q_varres(query_image, tar_width, tar_height)
+     reference_patches_pil = []
+     for reference_image in reference_images:
+         reference_patches_pil += process_image_ref_varres(reference_image, tar_width, tar_height)
+     combined_image = None
+     with torch.no_grad():
+         # Rank reference patches against each query patch by CLIP image-embedding similarity.
+         clip_img = image_processor(images=query_patches_pil, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+         query_embeddings = image_encoder(clip_img).image_embeds
+         reference_patches_pil_gray = [rimg.convert('RGB') for rimg in reference_patches_pil]
+         clip_img = image_processor(images=reference_patches_pil_gray, return_tensors="pt").pixel_values.to(image_encoder.device, dtype=image_encoder.dtype)
+         reference_embeddings = image_encoder(clip_img).image_embeds
+         cosine_similarities = F.cosine_similarity(query_embeddings.unsqueeze(1), reference_embeddings.unsqueeze(0), dim=-1)
+         sorted_indices = torch.argsort(cosine_similarities, descending=True, dim=1).tolist()
+         top_k = 3
+         top_k_indices = [cur_sortlist[:top_k] for cur_sortlist in sorted_indices]
+         # Assemble the conditioning canvas: the query in the center, top-k reference patches around it.
+         combined_image = Image.new('RGB', (tar_width * 2, tar_height * 2), 'white')
+         combined_image.paste(query_image_bw.resize((tar_width, tar_height)), (tar_width//2, tar_height//2))
+         idx_table = {0: [(1, 0), (0, 1), (0, 0)], 1: [(1, 3), (0, 2), (0, 3)], 2: [(2, 0), (3, 1), (3, 0)], 3: [(2, 3), (3, 2), (3, 3)]}
+         for i in range(2):
+             for j in range(2):
+                 idx_list = idx_table[i * 2 + j]
+                 for k in range(top_k):
+                     ref_index = top_k_indices[i * 2 + j][k]
+                     idx_y = idx_list[k][0]
+                     idx_x = idx_list[k][1]
+                     combined_image.paste(reference_patches_pil[ref_index].resize((tar_width//2-2, tar_height//2-2)), (tar_width//2 * idx_x + 1, tar_height//2 * idx_y + 1))
+     gr.Info("Model inference in progress...")
+     generator = torch.Generator(device='cuda').manual_seed(seed)
+     image = pipeline(
+         "manga", cond_image=combined_image, cond_mask=validation_mask, num_inference_steps=num_inference_steps, generator=generator
+     ).images[0]
+     gr.Info("Post-processing image...")
+     with torch.no_grad():
+         # Crop the colorized query out of the center of the generated canvas and upscale it.
+         width, height = image.size
+         new_width = width // 2
+         new_height = height // 2
+         left = (width - new_width) // 2
+         top = (height - new_height) // 2
+         right = left + new_width
+         bottom = top + new_height
+         center_crop = image.crop((left, top, right, bottom))
+         up_img = center_crop.resize(query_image_vae.size)
+         test_low_color = transform(up_img).unsqueeze(0).to('cuda', dtype=weight_dtype)
+         query_image_vae = transform(query_image_vae).unsqueeze(0).to('cuda', dtype=weight_dtype)
+
+         # Encode both the upscaled color result and the 1.5x grayscale input with the VAE,
+         # fuse their intermediate features, and decode to a high-resolution color image.
+         h_color, hidden_list_color = pipeline.vae._encode(test_low_color, return_dict=False, hidden_flag=True)
+         h_bw, hidden_list_bw = pipeline.vae._encode(query_image_vae, return_dict=False, hidden_flag=True)
+
+         hidden_list_double = [torch.cat((hidden_list_color[hidden_idx], hidden_list_bw[hidden_idx]), dim=1) for hidden_idx in range(len(hidden_list_color))]
+
+         hidden_list = MultiResNetModel(hidden_list_double)
+         output = pipeline.vae._decode(h_color.sample(), return_dict=False, hidden_list=hidden_list)[0]
+
+         output[output > 1] = 1
+         output[output < -1] = -1
+         high_res_image = Image.fromarray(((output[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8)).convert("RGB")
+     gr.Info("Colorization complete!")
+     torch.cuda.empty_cache()
+     return high_res_image, up_img, image, query_image_bw
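
app_func.py only defines the preprocessing and colorization callbacks; the Gradio interface that calls them is not part of this commit. The sketch below shows one plausible way to wire `extract_line_image` and `colorize_image` into a `gr.Blocks` demo. It is a minimal, assumption-laden example: the component names, the `gr.State` slots holding the preprocessed images, and the use of `gr.Files` for references are illustrative choices, not the repository's actual UI.

```python
# Hypothetical wiring of the callbacks above into a Gradio demo.
# Only extract_line_image and colorize_image come from app_func.py;
# everything else here is an assumed layout.
import gradio as gr

with gr.Blocks() as demo:
    # State slots for the two intermediate images returned by preprocessing.
    VAE_input = gr.State()       # image later re-encoded by the VAE branch
    input_context = gr.State()   # extracted line/screen image used for retrieval

    with gr.Row():
        query_image = gr.Image(type="pil", label="Input image")
        extracted_preview = gr.Image(type="pil", label="Extracted lines")
        result_image = gr.Image(type="pil", label="Colorized result (high-res)")
    with gr.Row():
        up_preview = gr.Image(type="pil", label="Upscaled center crop")
        raw_canvas = gr.Image(type="pil", label="Raw generated canvas")
        bw_preview = gr.Image(type="pil", label="Query (B/W)")

    reference_files = gr.Files(label="Reference images")
    input_style = gr.Dropdown(["GrayImage(ScreenStyle)", "Sketch"],
                              value="GrayImage(ScreenStyle)", label="Input style")
    resolution = gr.Dropdown(["640x640", "512x800", "800x512"],
                             value="640x640", label="Resolution")
    seed = gr.Number(value=0, precision=0, label="Seed")
    steps = gr.Slider(1, 50, value=10, step=1, label="Inference steps")

    preprocess_btn = gr.Button("Preprocess")
    colorize_btn = gr.Button("Colorize")

    # extract_line_image returns (input_context, extracted_line, input_context).
    preprocess_btn.click(
        extract_line_image,
        inputs=[query_image, input_style, resolution],
        outputs=[VAE_input, extracted_preview, input_context],
    )
    # colorize_image returns (high_res_image, up_img, image, query_image_bw).
    colorize_btn.click(
        colorize_image,
        inputs=[VAE_input, input_context, reference_files, resolution, seed, input_style, steps],
        outputs=[result_image, up_preview, raw_canvas, bw_preview],
    )

demo.launch()
```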