#!/usr/bin/env python3

# Copyright 2023 Yi Xie
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import platform
import shutil
import sys
import zipfile

parser = argparse.ArgumentParser(
    prog=os.path.basename(__file__),
    description='Convert an ML model to waifu2x app custom model',
)
parser.add_argument('filename')
required_args = parser.add_argument_group('required')
required_args.add_argument('--type',
                           choices=['esrgan_old', 'esrgan_old_lite', 'real_esrgan',
                                    'real_esrgan_compact', 'esrgan_plus'],
                           required=True,
                           help='Type of the model')
required_args.add_argument('--name', type=str, required=True,
                           help='Name of the model')
required_args.add_argument('--scale', type=int, required=True,
                           help='Scale factor of the model')
required_args.add_argument('--out-dir', type=str, required=True,
                           help='Output directory')
optional_args = parser.add_argument_group('optional')
optional_args.add_argument('--monochrome', action='store_true',
                           help='Input model is monochrome (single channel)')
optional_args.add_argument('--has-cuda', action='store_true',
                           help='Input model contains CUDA object')
optional_args.add_argument('--num-features', type=int,
                           help='Override number of features for (Real-)ESRGAN model')
optional_args.add_argument('--num-blocks', type=int,
                           help='Override number of blocks for (Real-)ESRGAN model')
optional_args.add_argument('--num-convs', type=int,
                           help='Override number of conv layers for Real-ESRGAN Compact model')
optional_args.add_argument('--shuffle-factor', type=int,
                           help='Shuffle input channels in ESRGAN model')
optional_args.add_argument('--input-size', type=int, default=256,
                           help='Input size (both width and height), default to 256')
optional_args.add_argument('--shrink-size', type=int, default=20,
                           help='Shrink size (applied to all 4 sides on input), default to 20')
optional_args.add_argument('--description', type=str, required=False,
                           help='Description of the model, supports Markdown')
optional_args.add_argument('--source', type=str, required=False,
                           help='Source of the model, supports Markdown')
optional_args.add_argument('--author', type=str, required=False,
                           help='Author of the model, supports Markdown')
optional_args.add_argument('--license', type=str, required=False,
                           help='License of the model, supports Markdown')
optional_args.add_argument('--info-md', type=str, required=False,
                           help='Use custom info.md instead of individual flags')
optional_args.add_argument('--no-delete-mlmodel', action='store_true',
                           help="Don't delete the intermediate Core ML model file")
args = parser.parse_args()

logger = logging.getLogger('converter')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)

if args.input_size % 4 != 0:
    logger.fatal('Input size must be multiple of 4')
    sys.exit(-1)
if args.shrink_size < 0:
    logger.fatal('Shrink size must not be < 0')
    sys.exit(-1)
if args.input_size - 2 * args.shrink_size < 4:
    logger.fatal('Input size after shrinking is too small')
    sys.exit(-1)
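# Tile-size arithmetic: the app processes images in square tiles of
# --input-size pixels and discards a --shrink-size border on each side
# (presumably to hide tile seams), so each tile yields
# (input_size - 2 * shrink_size) * scale useful output pixels.
# With the defaults (256, 20) and a 4x model: (256 - 40) * 4 = 864.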

os.makedirs(args.out_dir, exist_ok=True)

import coremltools as ct
import torch

torch_model = None

# Pick the fastest available torch device for tracing.
device = torch.device('cpu')
if platform.system() == 'Darwin' and torch.backends.mps.is_available():
    device = torch.device('mps')
    logger.info('Using torch device mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
    logger.info('Using torch device cuda')
else:
    logger.info('Using torch device cpu, please be patient')

logger.info('Creating model architecture')
in_channels = 3
out_channels = 3
model_scale = args.scale
if args.monochrome:
    in_channels = 1
    out_channels = 1
if args.shuffle_factor:
    # pixel_unshuffle trades spatial resolution for channels: the wrapped
    # model sees shuffle_factor^2 times the channels at 1/shuffle_factor
    # the resolution, so it must upscale by shuffle_factor more to compensate.
    in_channels *= args.shuffle_factor * args.shuffle_factor
    model_scale *= args.shuffle_factor
num_features = 64
num_blocks = 23
num_convs = 16
if args.type == 'esrgan_old_lite':
    num_features = 32
    num_blocks = 12
if args.num_features is not None:
    num_features = args.num_features
if args.num_blocks is not None:
    num_blocks = args.num_blocks
if args.num_convs is not None:
    num_convs = args.num_convs

if args.type == 'esrgan_old' or args.type == 'esrgan_old_lite':
    from esrgan_old import architecture
    torch_model = architecture.RRDB_Net(
        in_channels, out_channels, num_features, num_blocks, gc=32,
        upscale=model_scale, norm_type=None, act_type='leakyrelu',
        mode='CNA', res_scale=1, upsample_mode='upconv')
elif args.type == 'real_esrgan':
    from basicsr.archs.rrdbnet_arch import RRDBNet
    torch_model = RRDBNet(num_in_ch=in_channels, num_out_ch=out_channels,
                          num_feat=num_features, num_block=num_blocks,
                          num_grow_ch=32, scale=args.scale)
elif args.type == 'real_esrgan_compact':
    from basicsr.archs.srvgg_arch import SRVGGNetCompact
    torch_model = SRVGGNetCompact(num_in_ch=in_channels, num_out_ch=out_channels,
                                  num_feat=num_features, num_conv=num_convs,
                                  upscale=args.scale, act_type='prelu')
elif args.type == 'esrgan_plus':
    from esrgan_plus.codes.models.modules.architecture import RRDBNet
    torch_model = RRDBNet(in_nc=in_channels, out_nc=out_channels,
                          nf=num_features, nb=num_blocks, gc=32,
                          upscale=args.scale)
else:
    logger.fatal('Unknown model type: %s', args.type)
    sys.exit(-1)

logger.info('Loading weights')
loadnet = None
if args.has_cuda:
    loadnet = torch.load(args.filename, map_location=device)
else:
    loadnet = torch.load(args.filename)
# Real-ESRGAN checkpoints wrap the weights in 'params_ema' or 'params'.
if 'params_ema' in loadnet:
    loadnet = loadnet['params_ema']
elif 'params' in loadnet:
    loadnet = loadnet['params']


def mod2normal(state_dict):
    # this code is copied from https://github.com/victorca25/iNNfer
    if 'conv_first.weight' in state_dict:
        crt_net = {}
        items = list(state_dict)
        crt_net['model.0.weight'] = state_dict['conv_first.weight']
        crt_net['model.0.bias'] = state_dict['conv_first.bias']
        for k in items.copy():
            if 'RDB' in k:
                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
                if '.weight' in k:
                    ori_k = ori_k.replace('.weight', '.0.weight')
                elif '.bias' in k:
                    ori_k = ori_k.replace('.bias', '.0.bias')
                crt_net[ori_k] = state_dict[k]
                items.remove(k)
        crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
        crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
        crt_net['model.3.weight'] = state_dict['upconv1.weight']
        crt_net['model.3.bias'] = state_dict['upconv1.bias']
        crt_net['model.6.weight'] = state_dict['upconv2.weight']
        crt_net['model.6.bias'] = state_dict['upconv2.bias']
        crt_net['model.8.weight'] = state_dict['HRconv.weight']
        crt_net['model.8.bias'] = state_dict['HRconv.bias']
        crt_net['model.10.weight'] = state_dict['conv_last.weight']
        crt_net['model.10.bias'] = state_dict['conv_last.bias']
        state_dict = crt_net
    return state_dict
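# For example, mod2normal maps a new-style checkpoint key such as
# 'RRDB_trunk.0.RDB1.conv1.weight' to the old-style sequential key
# 'model.1.sub.0.RDB1.conv1.0.weight' expected by esrgan_old's RRDB_Net.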
try:
    torch_model.load_state_dict(loadnet, strict=True)
except Exception as e:
    # Fall back to remapping new-style ESRGAN keys to the old layout.
    if 'conv_first.weight' in loadnet:
        loadnet = mod2normal(loadnet)
        torch_model.load_state_dict(loadnet, strict=True)
    else:
        raise e

if args.monochrome:
    from torch import nn

    class MonochromeWrapper(nn.Module):
        def __init__(self, model: nn.Module):
            super(MonochromeWrapper, self).__init__()
            self.model = model

        def forward(self, x: torch.Tensor):
            # Collapse RGB to a single channel, run the model, then
            # replicate the result back to 3 channels.
            x = torch.mean(x, dim=1, keepdim=True)
            x = self.model(x)
            x = x.repeat([1, 3, 1, 1])
            return x

    torch_model = MonochromeWrapper(torch_model)

if args.shuffle_factor:
    from torch import nn
    import torch.nn.functional as F

    # Source: https://github.com/chaiNNer-org/spandrel/blob/cb2f03459819ce114c52e328b7ac9bb2812f205a/libs/spandrel/spandrel/architectures/__arch_helpers/padding.py
    def pad_to_multiple(
        tensor: torch.Tensor,
        multiple: int,
        *,
        mode: str,
        value: float = 0.0,
    ) -> torch.Tensor:
        _, _, h, w = tensor.size()
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        if pad_h or pad_w:
            return F.pad(tensor, (0, pad_w, 0, pad_h), mode, value)
        return tensor

    class ShuffleWrapper(nn.Module):
        def __init__(self, model: nn.Module):
            super(ShuffleWrapper, self).__init__()
            self.model = model

        def forward(self, x: torch.Tensor):
            _, _, h, w = x.size()
            x = pad_to_multiple(x, args.shuffle_factor, mode="reflect")
            x = torch.pixel_unshuffle(x, downscale_factor=args.shuffle_factor)
            x = self.model(x)
            # Crop away any output produced by the reflection padding.
            return x[:, :, : h * model_scale, : w * model_scale]

    torch_model = ShuffleWrapper(torch_model)

logger.info('Tracing model, will take a long time and a lot of RAM')
torch_model.eval()
torch_model = torch_model.to(device)
example_input = torch.zeros(1, 3, 16, 16)
example_input = example_input.to(device)
traced_model = torch.jit.trace(torch_model, example_input)
out = traced_model(example_input)
logger.info('Successfully traced model')

# Sanity-check the actual scale factor of the traced model.
input_size = example_input.shape[-1]
output_size = out.shape[-1]
if args.scale != output_size / input_size:
    logger.fatal('Expected output scale to be %d, but is actually %.2f',
                 args.scale, output_size / input_size)
    sys.exit(-1)

logger.info('Converting to Core ML')
input_shape = [1, 3, args.input_size, args.input_size]
output_size = args.input_size * args.scale
output_shape = [1, 3, output_size, output_size]
minimum_deployment_target = None
if args.shuffle_factor:
    # The pixel_unshuffle op requires the iOS 16 Core ML op set.
    minimum_deployment_target = ct.target.iOS16
model = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[ct.TensorType(shape=input_shape)],
    minimum_deployment_target=minimum_deployment_target,
)
model_name = os.path.basename(args.filename).split('.')[0]
mlmodel_file = os.path.join(args.out_dir, model_name + '.mlpackage')
model.save(mlmodel_file)
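# The .wifm written below is a plain ZIP archive containing manifest.json
# (model metadata), info.md (display text), and the .mlpackage itself.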
logger.info('Packaging model')
spec = model.get_spec()
input_name = spec.description.input[0].name
output_name = spec.description.output[0].name
logger.debug('Model input name: %s, size: %s', input_name, args.input_size)
output_size_shrinked = (args.input_size - 2 * args.shrink_size) * args.scale
logger.debug('Model output name: %s, size: %s, after shrinking: %s',
             output_name, output_size, output_size_shrinked)
manifest = {
    "version": 1,
    "name": args.name,
    "type": "coreml",
    "subModels": {
        "main": {
            "file": mlmodel_file,
            "inputName": input_name,
            "outputName": output_name
        }
    },
    "dataFormat": "nchw",
    "inputShape": input_shape,
    "shrinkSize": args.shrink_size,
    "scale": args.scale,
    "alphaMode": "sameAsMain"
}

info_md = '''
{}
===

Converted by [waifu2x-ios-model-converter](https://github.com/imxieyi/waifu2x-ios-model-converter).
'''.format(args.name)
if args.description is not None:
    info_md += '''
## Description

{}
'''.format(args.description)
if args.author is not None:
    info_md += '''
## Author

{}
'''.format(args.author)
if args.source is not None:
    info_md += '''
## Source

{}
'''.format(args.source)
if args.license is not None:
    info_md += '''
## License

{}
'''.format(args.license)
if args.info_md is not None:
    # --info-md replaces the generated info.md wholesale.
    with open(args.info_md) as f:
        info_md = f.read()
if len(info_md) > 1024 * 1024:
    logger.fatal('Model info.md too large. Try to reduce license file size, etc.')
    sys.exit(-1)


def add_folder_to_zip(folder, zf):
    # Recursively add every file under folder, preserving its path.
    for folder_name, subfolders, filenames in os.walk(folder):
        for filename in filenames:
            file_path = os.path.join(folder_name, filename)
            zf.write(file_path, file_path)


zip_file = os.path.join(args.out_dir, args.name + '.wifm')
with zipfile.ZipFile(zip_file, 'w', compression=zipfile.ZIP_DEFLATED) as modelzip:
    modelzip.writestr('manifest.json', json.dumps(manifest))
    modelzip.writestr('info.md', info_md)
    if os.path.isfile(mlmodel_file):
        modelzip.write(mlmodel_file)
    else:
        # .mlpackage models are directories and must be added file by file.
        add_folder_to_zip(mlmodel_file, modelzip)

if not args.no_delete_mlmodel:
    if os.path.isfile(mlmodel_file):
        os.remove(mlmodel_file)
    else:
        shutil.rmtree(mlmodel_file)

logger.info('Successfully converted model: %s', zip_file)
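# Example invocation (script and checkpoint names are hypothetical,
# for illustration only):
#
#   python3 convert-model.py RealESRGAN_x4plus.pth \
#       --type real_esrgan --name 'Real-ESRGAN x4plus' --scale 4 \
#       --out-dir out
#
# which produces 'out/Real-ESRGAN x4plus.wifm' for import into the app.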