Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision | |
| from torchvision.transforms import transforms | |
| import numpy as np | |
| from PIL import Image | |
class SRCNNModel(nn.Module):
    """Three-layer convolutional network mapping a 1-channel image to a 1-channel image.

    The layers are a 9x9 convolution (1 -> 64 channels), a 1x1 convolution
    (64 -> 32 channels) and a 5x5 convolution (32 -> 1 channel). Each padding
    is ``(kernel_size - 1) // 2`` with the default stride of 1, so the output
    has the same height and width as the input.
    """

    def __init__(self):
        super().__init__()
        # The attribute names become the state-dict keys, so they must stay
        # ``conv1``/``conv2``/``conv3`` for saved weights to load.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=1, padding=0)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, padding=2)

    def forward(self, x):
        """Apply conv1 and conv2 with ReLU, then conv3 without an activation."""
        features = F.relu(self.conv1(x))
        mapped = F.relu(self.conv2(features))
        return self.conv3(mapped)
def pred_SRCNN(model, image, device, scale_factor=2):
    """Upscale a Pillow image and replace its luma channel with the model's prediction.

    The image is split into Y, Cb and Cr channels, each channel is enlarged
    with bicubic interpolation, the enlarged Y channel is passed through
    ``model``, and the predicted Y is merged back with the bicubic Cb/Cr.

    Args:
        model: SRCNN model; it is moved to ``device`` and put in eval mode.
        image: low-resolution Pillow image.
        device: torch device (cuda or cpu) the model runs on.
        scale_factor: factor by which width and height are multiplied.

    Returns:
        A tuple ``(out_final, image_bicubic, image)``: the RGB image built
        from the predicted Y channel, the plain bicubic upscaling of
        ``image``, and the input image itself.
    """
    model.to(device)
    model.eval()

    # Pillow sizes are (width, height); every channel shares the image size.
    width, height = image.size
    target_size = (width * scale_factor, height * scale_factor)

    def upscale(pil_image):
        # Same operation torchvision's Resize performs on a PIL image, without
        # passing a raw PIL constant through torchvision's interpolation argument.
        return pil_image.resize(target_size, Image.BICUBIC)

    y, cb, cr = image.convert('YCbCr').split()
    y_bicubic = upscale(y)
    cb_bicubic = upscale(cb)
    cr_bicubic = upscale(cr)

    # ToTensor scales pixel values to [0, 1]; unsqueeze adds the batch dimension.
    y_tensor = transforms.ToTensor()(y_bicubic).to(device).unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y_pred = model(y_tensor)

    # Take the first image of the batch and its first channel, then map the
    # values back to the 0-255 pixel range.
    y_array = y_pred[0, 0].cpu().numpy() * 255
    # The clipped array must be kept: ndarray.clip returns a new array, and
    # unclipped values would wrap around in the uint8 cast below.
    y_array = np.clip(y_array, 0, 255)

    # A 2-D uint8 array is read by Pillow as a single-channel 'L' image.
    y_pred_PIL = Image.fromarray(y_array.astype(np.uint8))

    # Merge the predicted Y channel with the bicubic Cb and Cr channels.
    out_final = Image.merge('YCbCr', [y_pred_PIL, cb_bicubic, cr_bicubic]).convert('RGB')
    image_bicubic = upscale(image)
    return out_final, image_bicubic, image
def main():
    """Load the trained SRCNN weights, upscale ``LR_image.png`` and show the results.

    Three images are displayed: the low-resolution input, the SRCNN output
    and the plain bicubic upscaling.
    """
    print("Loading SRCNN model...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SRCNNModel().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('SRCNNmodel_trained.pt', map_location=device))
    model.eval()
    print("SRCNN model loaded!")

    image_path = "LR_image.png"
    # pred_SRCNN takes an opened Pillow image through its ``image`` parameter;
    # it has no ``image_path`` parameter.
    with Image.open(image_path) as low_res:
        out_final, image_bicubic, image = pred_SRCNN(model=model, image=low_res, device=device)
        image.show()
        out_final.show()
        image_bicubic.show()


if __name__ == "__main__":
    main()