d0tpy commited on
Commit
93eb6fd
·
verified ·
1 Parent(s): d01a4c2
Files changed (1) hide show
  1. image_enhancer.oy +124 -0
image_enhancer.oy ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from gfpgan import GFPGANer
4
+ from tqdm import tqdm
5
+ import cv2
6
+ from enum import Enum
7
+
8
+ class EnhancementMethod(str, Enum):
9
+ gfpgan = "gfpgan"
10
+ RestoreFormer = "RestoreFormer"
11
+ codeformer = "codeformer"
12
+ realesrgan = "realesrgan"
13
+
14
+
15
+ class Enhancer:
16
+ def __init__(self, method=EnhancementMethod, background_enhancement=True, upscale=2):
17
+ # Set up RealESRGAN for background enhancement
18
+ if background_enhancement:
19
+ if upscale == 2:
20
+ if not torch.cuda.is_available(): # CPU
21
+ import warnings
22
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
23
+ 'If you really want to use it, please modify the corresponding codes.')
24
+ self.bg_upsampler = None
25
+ else:
26
+ from basicsr.archs.rrdbnet_arch import RRDBNet
27
+ from realesrgan import RealESRGANer
28
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
29
+ self.bg_upsampler = RealESRGANer(
30
+ scale=2,
31
+ model_path='https://huggingface.co/dtarnow/UPscaler/resolve/main/RealESRGAN_x2plus.pth',
32
+ model=model,
33
+ tile=400,
34
+ tile_pad=10,
35
+ pre_pad=0,
36
+ half=True) # need to set False in CPU mode
37
+ elif upscale == 4:
38
+ if not torch.cuda.is_available(): # CPU
39
+ import warnings
40
+ warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
41
+ 'If you really want to use it, please modify the corresponding codes.')
42
+ self.bg_upsampler = None
43
+ else:
44
+ from basicsr.archs.rrdbnet_arch import RRDBNet
45
+ from realesrgan import RealESRGANer
46
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
47
+ self.bg_upsampler = RealESRGANer(
48
+ scale=4,
49
+ model_path='https://huggingface.co/lllyasviel/Annotators/resolve/main/RealESRGAN_x4plus.pth',
50
+ model=model,
51
+ tile=400,
52
+ tile_pad=10,
53
+ pre_pad=0,
54
+ half=True) # need to set False in CPU mode
55
+ else:
56
+ raise ValueError(f'Wrong upscale constant {upscale}.')
57
+ else:
58
+ self.bg_upsampler = None
59
+
60
+ # Set up GPFGAN for face enhancement
61
+ if method == 'gfpgan':
62
+ self.arch = 'clean'
63
+ self.channel_multiplier = 2
64
+ self.model_name = 'GFPGANv1.4'
65
+ self.url = 'https://huggingface.co/gmk123/GFPGAN/resolve/main/GFPGANv1.4.pth'
66
+ elif method == 'RestoreFormer':
67
+ self.arch = 'RestoreFormer'
68
+ self.channel_multiplier = 2
69
+ self.model_name = 'RestoreFormer'
70
+ self.url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
71
+ elif method == 'codeformer': # TODO:
72
+ self.arch = 'CodeFormer'
73
+ self.channel_multiplier = 2
74
+ self.model_name = 'CodeFormer'
75
+ self.url = 'https://huggingface.co/sinadi/aar/resolve/main/codeformer.pth'
76
+ else:
77
+ raise ValueError(f'Wrong model version {method}.')
78
+
79
+ # Determine the model path and if the model is not available, download it
80
+ model_path = os.path.join('gfpgan/weights', self.model_name + '.pth')
81
+
82
+ if not os.path.isfile(model_path):
83
+ model_path = os.path.join('checkpoints', self.model_name + '.pth')
84
+
85
+ if not os.path.isfile(model_path):
86
+ # Download pre-trained models from url
87
+ model_path = self.url
88
+
89
+ self.restorer = GFPGANer(
90
+ model_path=model_path,
91
+ upscale=upscale,
92
+ arch=self.arch,
93
+ channel_multiplier=self.channel_multiplier,
94
+ bg_upsampler=self.bg_upsampler)
95
+
96
+
97
+ def check_image_dimensions(self, image):
98
+ # Get the dimensions of the image
99
+ height, width, _ = image.shape
100
+ return True
101
+
102
+ # Check if either dimension exceeds 2048 pixels :Todo
103
+ # if width > 2048 or height > 2048:
104
+ # return True
105
+
106
+ # else:
107
+ # print("Image dimensions are within the limit.")
108
+ # return True
109
+
110
+
111
+ def enhance(self, image):
112
+ img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
113
+ if self.check_image_dimensions(img):
114
+ cropped_faces, restored_faces, r_img = self.restorer.enhance(
115
+ img,
116
+ has_aligned=False,
117
+ only_center_face=False,
118
+ paste_back=True)
119
+ else:
120
+ r_img = img
121
+
122
+ r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
123
+
124
+ return r_img