from model import SwinIR
import torch
import cv2  # only needed for the commented-out local test at the bottom
import numpy as np
from PIL import Image
import gradio as gr
# import streamlit as st


def get_image(img_lq, tile=200):
    # img_gt = cv2.imread(input_img, cv2.IMREAD_COLOR).astype(np.float32) / 255.
    if tile is None:
        # test the image as a whole
        output = model(img_lq)
    else:
        # test the image tile by tile and blend the overlapping outputs
        b, c, h, w = img_lq.size()
        tile = min(tile, h, w)
        # assert tile % window_size == 0, "tile size should be a multiple of window_size"
        tile_overlap = 32
        sf = 2  # scale factor, must match the model's upscale setting
        stride = tile - tile_overlap
        h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
        w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
        # E accumulates the upscaled patches, W counts how many patches cover each output pixel
        E = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
        W = torch.zeros_like(E)

        for h_idx in h_idx_list:
            for w_idx in w_idx_list:
                in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)

                E[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch)
                W[..., h_idx * sf:(h_idx + tile) * sf, w_idx * sf:(w_idx + tile) * sf].add_(out_patch_mask)
        # average the overlapping regions
        output = E.div_(W)
    return output


model = SwinIR(embed_dim=180, window_size=8, depths=[6, 6, 6, 6, 6, 6],
               num_heads=[6, 6, 6, 6, 6, 6], mlp_ratio=2, upscale=2,
               upsampler='nearest+conv')
model.load_state_dict(torch.load('100000_E.pth', map_location='cpu'))
model.eval()  # inference only


def sr(input_img):
    window_size = 8
    # input_img is a PIL RGB image from Gradio; normalize to [0, 1]
    img_lq = np.asarray(input_img).astype(np.float32) / 255.
    # HWC-RGB to CHW-BGR (the channel order is flipped back after inference)
    img_lq = np.transpose(img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1))
    img_lq = torch.from_numpy(img_lq).float().unsqueeze(0)
    with torch.no_grad():
        # pad height and width up to multiples of window_size by mirror-reflecting the image
        _, _, h_old, w_old = img_lq.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[:, :, :h_old + h_pad, :]
        img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
        output = get_image(img_lq)
        # crop away the padding (x2 upscale)
        output = output[..., :h_old * 2, :w_old * 2]
        # output = output[..., :138 * 2, :138 * 2]
    # convert the CHW float tensor back to an HWC uint8 PIL image
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    if output.ndim == 3:
        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))  # CHW-BGR back to HWC-RGB
    output = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output)


# sr(cv2.imread('/Users/apsys/Downloads/Set14/image_SRF_2/img_006_SRF_2_LR.png', cv2.IMREAD_COLOR).astype(np.float32) / 255.)

demo = gr.Interface(sr, gr.Image(type='pil'), "image")
demo.launch()
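
# A minimal local test sketch, as an alternative to the commented-out cv2 call above:
# `sr` expects a PIL-style RGB image (what Gradio supplies), so passing a cv2 float
# array would normalize and channel-swap the data twice. The paths 'test_lr.png' and
# 'test_sr.png' below are hypothetical placeholders, not files shipped with this code.
#
#   lr = Image.open('test_lr.png').convert('RGB')  # load a low-resolution RGB image
#   sr(lr).save('test_sr.png')                     # run x2 super-resolution and save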