Wiuhh commited on
Commit
24bace7
·
verified ·
1 Parent(s): 0a1de88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -95
app.py CHANGED
@@ -1,115 +1,164 @@
1
  import os
2
-
3
  import sys
 
 
 
 
4
  from torchvision.transforms import functional
5
- sys.modules["torchvision.transforms.functional_tensor"] = functional
6
- # //sequntila NotImplemented
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
9
  from gfpgan.utils import GFPGANer
10
  from realesrgan.utils import RealESRGANer
11
 
12
- import torch
13
- import cv2
14
- import gradio as gr
15
-
16
-
17
- #Download Required Models
18
- if not os.path.exists('realesr-general-x4v3.pth'):
19
- os.system("wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth -P .")
20
- if not os.path.exists('GFPGANv1.2.pth'):
21
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
22
- if not os.path.exists('GFPGANv1.3.pth'):
23
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
24
- if not os.path.exists('GFPGANv1.4.pth'):
25
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
26
- if not os.path.exists('RestoreFormer.pth'):
27
- os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
28
-
29
-
30
- model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
31
- model_path = 'realesr-general-x4v3.pth'
32
- half = True if torch.cuda.is_available() else False
33
- upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
34
-
35
-
36
- # Save Image to the Directory
37
- # os.makedirs('output', exist_ok=True)
38
-
39
- def upscaler(img, version, scale):
40
 
41
  try:
 
 
42
 
43
- img = cv2.imread(img, cv2.IMREAD_UNCHANGED)
44
- if len(img.shape) == 3 and img.shape[2] == 4:
45
- img_mode = 'RGBA'
46
- elif len(img.shape) == 2:
47
- img_mode = None
48
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
49
- else:
50
- img_mode = None
51
 
52
-
53
- h, w = img.shape[0:2]
54
- if h < 300:
55
- img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
56
-
57
-
58
  face_enhancer = GFPGANer(
59
- model_path=f'{version}.pth',
60
- upscale=2,
61
- arch='RestoreFormer' if version=='RestoreFormer' else 'clean',
62
  channel_multiplier=2,
63
- bg_upsampler=upsampler
64
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
-
67
- try:
68
- _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
69
- except RuntimeError as error:
70
- print('Error', error)
71
-
72
-
73
- try:
74
- if scale != 2:
75
- interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
76
- h, w = img.shape[0:2]
77
- output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
78
- except Exception as error:
79
- print('wrong scale input.', error)
80
-
81
- # Save Image to the Directory
82
- # ext = os.path.splitext(os.path.basename(str(img)))[1]
83
- # if img_mode == 'RGBA':
84
- # ext = 'png'
85
- # else:
86
- # ext = 'jpg'
87
- #
88
- # save_path = f'output/out.{ext}'
89
- # cv2.imwrite(save_path, output)
90
- # return output, save_path
91
-
92
- output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
93
- return output
94
  except Exception as error:
95
- print('global exception', error)
96
- return None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  if __name__ == "__main__":
99
-
100
- title = "NeuraVision ai image upscale"
101
-
102
- demo = gr.Interface(
103
- upscaler, [
104
- gr.Image(type="filepath", label="Input"),
105
- gr.Radio(['GFPGANv1.2', 'GFPGANv1.3', 'GFPGANv1.4', 'RestoreFormer'], type="value", label='version'),
106
- gr.Number(label="Rescaling factor"),
107
- ], [
108
- gr.Image(type="numpy", label="Output"),
109
- ],
110
- title=title,
111
- allow_flagging="never"
112
- )
113
-
114
  demo.queue()
115
- demo.launch()
 
1
  import os
 
2
  import sys
3
+ import tempfile
4
+ import cv2
5
+ import torch
6
+ import gradio as gr
7
  from torchvision.transforms import functional
 
 
8
 
9
+ # --- PATCH FOR COMPATIBILITY ---
10
+ sys.modules["torchvision.transforms.functional_tensor"] = functional
11
+
12
+ # --- EMBEDDED CSS FOR STYLING ---
13
+ CSS_STYLING = """
14
+ :root {
15
+ --primary: #6a35ee; --primary-dark: #4a1dcc; --secondary: #00c9ff;
16
+ --accent: #ff6b6b; --light: #f8f9ff; --dark: #1a1f36; --text: #4a5568;
17
+ --input-background-fill: var(--light) !important;
18
+ --input-border-color: #e0e0e0 !important;
19
+ --input-label-color: var(--text) !important;
20
+ }
21
+ .gradio-container { background: var(--light); font-family: 'Inter', sans-serif; }
22
+ #main-title { color: var(--dark); text-align: center; font-size: 2.5rem !important; font-weight: 900; }
23
+ #main-subtitle { color: var(--text); text-align: center; font-size: 1rem !important; margin-top: -15px; margin-bottom: 20px; }
24
+ #submit-button { background: var(--primary); color: white; font-weight: bold; border-radius: 8px !important; transition: all 0.3s ease; }
25
+ #submit-button:hover { background: var(--primary-dark); box-shadow: 0px 4px 15px rgba(106, 53, 238, 0.4); transform: translateY(-2px); }
26
+ .gr-image { border: 1px dashed var(--input-border-color) !important; border-radius: 12px !important; min-height: 300px; }
27
+ input[type="range"]::-webkit-slider-thumb { background: var(--primary) !important; }
28
+ input[type="range"]::-moz-range-thumb { background: var(--primary) !important; }
29
+ """
30
+
31
+ # --- DOWNLOAD HELPER FUNCTIONS ---
32
+ def download_file(url, dir_path, file_name):
33
+ """Downloads a file if it doesn't exist."""
34
+ os.makedirs(dir_path, exist_ok=True)
35
+ file_path = os.path.join(dir_path, file_name)
36
+ if not os.path.exists(file_path):
37
+ print(f"Downloading {file_name}...")
38
+ try:
39
+ os.system(f"wget {url} -O {file_path}")
40
+ print("Download complete.")
41
+ except Exception as e:
42
+ print(f"Error downloading {file_name}: {e}")
43
+ # In case wget is not available, you might need to use requests or urllib
44
+ # import requests
45
+ # with open(file_path, 'wb') as f:
46
+ # f.write(requests.get(url).content)
47
+ return file_path
48
+
49
+ # --- DOWNLOAD MODELS AND EXAMPLES ---
50
+ print("Checking for required files...")
51
+ # Models
52
+ models_dir = 'models'
53
+ download_file('https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', models_dir, 'realesr-general-x4v3.pth')
54
+ download_file('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', models_dir, 'GFPGANv1.4.pth')
55
+ download_file('https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth', models_dir, 'RestoreFormer.pth')
56
+
57
+ # Example Images
58
+ examples_dir = 'examples'
59
+ example1_path = download_file('https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/10045.png', examples_dir, 'example1.png')
60
+ example2_path = download_file('https://raw.githubusercontent.com/TencentARC/GFPGAN/master/inputs/whole_imgs/Blake_Lively.jpg', examples_dir, 'example2.jpg')
61
+
62
+ # --- LOAD MODELS INTO MEMORY ---
63
  from basicsr.archs.srvgg_arch import SRVGGNetCompact
64
  from gfpgan.utils import GFPGANer
65
  from realesrgan.utils import RealESRGANer
66
 
67
+ bg_upsampler = None
68
+ try:
69
+ model_path = os.path.join(models_dir, 'realesr-general-x4v3.pth')
70
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
71
+ half = torch.cuda.is_available()
72
+ bg_upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
73
+ print("Background Upsampler (Real-ESRGAN) loaded.")
74
+ except Exception as e:
75
+ print(f"Error loading background upsampler: {e}. The app may not work correctly.")
76
+
77
+ # --- CORE IMAGE PROCESSING FUNCTION ---
78
+ def upscale_image(img_path, version, scale):
79
+ """Enhance an image using GFPGAN and Real-ESRGAN."""
80
+ if not img_path:
81
+ raise gr.Error("Please upload an image.")
82
+ if not bg_upsampler:
83
+ raise gr.Error("Background upsampler not loaded. Cannot proceed.")
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  try:
86
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
87
+ if img is None: raise RuntimeError("Failed to read image.")
88
 
89
+ has_alpha = img.shape[2] == 4
 
 
 
 
 
 
 
90
 
 
 
 
 
 
 
91
  face_enhancer = GFPGANer(
92
+ model_path=os.path.join(models_dir, f'{version}.pth'),
93
+ upscale=2, # GFPGAN native upscale factor
94
+ arch='RestoreFormer' if version == 'RestoreFormer' else 'clean',
95
  channel_multiplier=2,
96
+ bg_upsampler=bg_upsampler
97
  )
98
+
99
+ _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
100
+
101
+ if scale != 2:
102
+ h, w = output.shape[0:2]
103
+ target_w, target_h = int(w * scale / 2), int(h * scale / 2)
104
+ if target_w > 8000 or target_h > 8000:
105
+ raise gr.Error(f"Target size is too large. Please choose a smaller scale.")
106
+ interpolation = cv2.INTER_LANCZOS4 if scale > 2 else cv2.INTER_AREA
107
+ output = cv2.resize(output, (target_w, target_h), interpolation=interpolation)
108
+
109
+ output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
110
+ ext = 'png' if has_alpha else 'jpg'
111
+
112
+ # Save to a temporary file for download
113
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{ext}') as temp_file:
114
+ cv2.imwrite(temp_file.name, cv2.cvtColor(output_rgb, cv2.COLOR_RGB2BGR))
115
+ return output_rgb, temp_file.name
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  except Exception as error:
118
+ print(f"Error processing image: {error}")
119
+ raise gr.Error(f"An error occurred: {error}")
120
+
121
+ # --- GRADIO UI LAYOUT ---
122
+ with gr.Blocks(css=CSS_STYLING, theme=gr.themes.Base()) as demo:
123
+ gr.Markdown("<h1 id='main-title'>NeuraVision AI Image Upscaler</h1>", elem_id="main-title")
124
+ gr.Markdown("<p id='main-subtitle'>Enhance old, blurry, and low-resolution photos with AI.</p>", elem_id="main-subtitle")
125
+
126
+ with gr.Row(variant="panel"):
127
+ # LEFT COLUMN (INPUT & CONTROLS)
128
+ with gr.Column(scale=1):
129
+ input_image = gr.Image(type="filepath", label="Upload Image")
130
+
131
+ version = gr.Radio(
132
+ ['GFPGANv1.4', 'RestoreFormer'], value='GFPGANv1.4',
133
+ label='AI Model', info="v1.4 is best for general use. RestoreFormer for very old photos."
134
+ )
135
+
136
+ scale = gr.Slider(
137
+ minimum=1, maximum=8, step=0.5, value=4,
138
+ label="Upscale Factor", info="How many times larger to make the image."
139
+ )
140
+
141
+ submit_btn = gr.Button("Enhance Image", variant="primary", elem_id="submit-button")
142
+
143
+ gr.Examples(
144
+ examples=[[example1_path, "RestoreFormer", 4], [example2_path, "GFPGANv1.4", 4]],
145
+ inputs=[input_image, version, scale],
146
+ label="Click an example to start"
147
+ )
148
+
149
+ # RIGHT COLUMN (OUTPUT)
150
+ with gr.Column(scale=1):
151
+ output_image = gr.Image(type="numpy", label="Enhanced Result", interactive=False)
152
+ download_button = gr.File(label="Download Image", interactive=False)
153
+
154
+ # --- BUTTON & EVENT HANDLING ---
155
+ submit_btn.click(
156
+ fn=upscale_image,
157
+ inputs=[input_image, version, scale],
158
+ outputs=[output_image, download_button]
159
+ )
160
+ input_image.clear(lambda: (None, None), None, [output_image, download_button])
161
 
162
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  demo.queue()
164
+ demo.launch(share=True) # Set share=False if you don't need a public link