import argparse
import os
import tempfile

import yaml
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as Image
from torchvision.utils import make_grid
import gradio as gr

from data_web import web_input
from models import DenoisingDiffusion, DiffusiveRestoration
import utils

# Optional override of the directory used for temporary files (left disabled).
# tempfile.tempdir = "/home/dachuang/gradio/tmp/"

# "-1" matches no GPU id, so this is intended to hide all CUDA devices and make
# the device selection below fall back to the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Text passed to the Gradio interface as its title (user-facing, shown as-is).
title_markdown = (""" 欢迎来到甲骨文文字演变模拟器 你只需要输入一张甲骨文图片,就可以看到它在不同随机种子下演变到汉字的结果。 输入的甲骨文保证是一个完整的文字和背景,不要有其他干扰。 """)


def dict2namespace(config):
    """Recursively convert a mapping into an ``argparse.Namespace``.

    Args:
        config: A dict whose keys become attribute names.

    Returns:
        argparse.Namespace: One attribute per key. Values that are themselves
        dicts are converted recursively; every other value (including lists
        that contain dicts) is stored unchanged.
    """
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            value = dict2namespace(value)
        setattr(namespace, key, value)
    return namespace


def config_get():
    """Parse the command line and load the YAML configuration it points to.

    The only recognised option is ``--config`` (default ``config_web.yml``).

    Returns:
        argparse.Namespace: The parsed YAML document converted with
        ``dict2namespace``.
    """
    parser = argparse.ArgumentParser()
    # Path of the configuration file.
    parser.add_argument(
        "--config",
        default='config_web.yml',
        type=str,
        required=False,
        help="Path to the config file",
    )
    args = parser.parse_args()

    # Decode explicitly as UTF-8 instead of the locale default, so a config
    # containing non-ASCII text loads the same way on every platform.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return dict2namespace(config)


config = config_get()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("=> using device: {}".format(device))
config.device = device

# torch.backends.cudnn.benchmark = True

# Build the diffusion model in test mode and wrap it in the restoration
# pipeline used by the web handler.
diffusion = DenoisingDiffusion(config, test=True)
model = DiffusiveRestoration(diffusion, config)
print("=> creating diffusion model")
# Seed pre-filled in the UI; also used when the seed field is left empty.
_DEFAULT_SEED = 61
# np.random.seed only accepts integers in [0, 2**32).
_SEED_MODULUS = 2 ** 32


def _normalize_seed(seed):
    """Coerce a UI-supplied seed to an int in ``[0, 2**32)``.

    ``None`` (an empty number field) maps to ``_DEFAULT_SEED``. Floats are
    truncated towards zero, and out-of-range values are wrapped with a modulo.
    """
    if seed is None:
        return _DEFAULT_SEED
    return int(seed) % _SEED_MODULUS


def sepia(image_web, seed):
    """Gradio handler: run the restoration model on one uploaded image.

    Args:
        image_web: Image delivered by the ``gr.Image`` input; it is passed
            unchanged to ``web_input``.
        seed: Random seed from the ``gr.Number`` input. It may arrive as a
            float or ``None``, so it is normalised before seeding.

    Returns:
        tuple: ``(ndarr, image_tmp)``. ``ndarr`` is the model output tiled
        with ``make_grid`` and converted to a uint8 NumPy array with the
        channel axis last. ``image_tmp`` is the second value returned by
        ``web_input``, forwarded unchanged to the second output image.
    """
    # np.random.seed raises on floats and on values outside [0, 2**32), so
    # normalise once and seed every generator with the same integer.
    seed = _normalize_seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    image, image_tmp = web_input(image_web)
    print("开始生成")
    output_image = model.web_restore(image, r=config.data.grid_r)

    grid = make_grid(output_image)
    # Scale to 0..255 and add 0.5 so the truncating uint8 cast rounds to
    # nearest, then move channels last (assumes values in [0, 1] - TODO confirm).
    scaled = grid.mul(255).add_(0.5).clamp_(0, 255)
    ndarr = scaled.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    # torch.cuda.empty_cache()
    print("生成完成")
    return ndarr, image_tmp


demo = gr.Interface(
    sepia,
    inputs=[
        gr.Image(label="输入甲骨文图片", height=600, width=600),
        gr.Number(label="随机种子", value=_DEFAULT_SEED),
    ],
    outputs=[
        gr.Image(label="输出汉字图片", height=600, width=600),
        gr.Image(label="矫正后甲骨文图片", height=600, width=600),
    ],
    title=title_markdown,
)

demo.launch(
    # server_name="127.0.0.1",
    # server_port=7681,
    share=True
)