|
import argparse |
|
import os |
|
import yaml |
|
import torch |
|
import torch.backends.cudnn as cudnn |
|
import numpy as np |
|
from data_web import web_input |
|
from models import DenoisingDiffusion, DiffusiveRestoration |
|
import utils |
|
import PIL.Image as Image |
|
from torchvision.utils import make_grid |
|
import gradio as gr |
|
import tempfile |
|
|
|
|
|
|
|
# NOTE(review): setting CUDA_VISIBLE_DEVICES to "-1" is the usual way to hide
# all GPUs from CUDA, so the device check further down presumably always
# selects the CPU -- confirm this is intended for the deployed demo.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
|
|
# Heading text for the Gradio page (passed as `title=` to gr.Interface).
# Rough translation: "Welcome to the oracle-bone script evolution simulator.
# Upload one oracle-bone character image to see how it evolves into a modern
# Chinese character under different random seeds. The input should contain a
# single complete character on its background, with nothing else in the way."
title_markdown = """
欢迎来到甲骨文文字演变模拟器
你只需要输入一张甲骨文图片,就可以看到它在不同随机种子下演变到汉字的结果。
输入的甲骨文保证是一个完整的文字和背景,不要有其他干扰。
"""
|
|
|
|
|
def config_get():
    """Parse the command line and load the YAML configuration it names.

    The only option is ``--config`` (default ``config_web.yml``), the path of
    the YAML file to read.

    Returns:
        argparse.Namespace: the loaded YAML content converted with
        ``dict2namespace``. The top level of the file must be a mapping,
        because ``dict2namespace`` iterates over ``.items()``.

    Raises:
        OSError: if the configuration file cannot be opened.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default='config_web.yml',
        type=str,
        required=False,
        help="Path to the config file",
    )
    args = parser.parse_args()

    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, so a config containing non-ASCII text loads the same way on
    # every machine. (The former single-argument os.path.join was a no-op.)
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    return dict2namespace(config)
|
|
|
|
|
def dict2namespace(config):
    """Recursively turn a (possibly nested) dict into an ``argparse.Namespace``.

    Every key of ``config`` becomes an attribute of the result. A value that
    is itself a dict is converted the same way; any other value (including a
    list that happens to contain dicts) is stored unchanged.

    Args:
        config: mapping to convert.

    Returns:
        argparse.Namespace: attribute-style view of ``config``.
    """
    ns = argparse.Namespace()
    for name, item in config.items():
        # Only direct dict values are descended into.
        converted = dict2namespace(item) if isinstance(item, dict) else item
        setattr(ns, name, converted)
    return ns
|
|
|
|
|
# Load the run configuration; note that this parses the command line as soon
# as the module is executed.
config = config_get()

# Use CUDA when torch reports it as available, otherwise the CPU, and record
# the choice on the config object handed to the model classes below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("=> using device: {}".format(device))
config.device = device

# Build the diffusion model with test=True and wrap it in DiffusiveRestoration,
# whose web_restore() method is what the request handler calls.
diffusion = DenoisingDiffusion(config, test=True)
model = DiffusiveRestoration(diffusion, config)
print("=> creating diffusion model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sepia(image_web, seed):
    """Gradio request handler: run the restoration model on one uploaded image.

    Args:
        image_web: the uploaded image exactly as delivered by the ``gr.Image``
            input; it is passed unchanged to ``web_input``.
        seed: random seed from the ``gr.Number`` input. It may arrive as a
            float, so it is truncated to an int before use.

    Returns:
        tuple: ``(ndarr, image_tmp)`` where ``ndarr`` is the model output as a
        uint8 NumPy array and ``image_tmp`` is the second value returned by
        ``web_input`` (shown in the UI as the corrected input image).
    """
    # gr.Number can hand over a float (e.g. 61.0); np.random.seed rejects
    # non-integer seeds, so coerce once and use the same int everywhere.
    seed = int(seed)
    torch.manual_seed(seed)
    # NumPy's legacy seeding only accepts 0 <= seed <= 2**32 - 1; wrap
    # negative or oversized values instead of failing the request.
    np.random.seed(seed % (2 ** 32))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    image, image_tmp = web_input(image_web)
    print("开始生成")
    output_image = model.web_restore(image, r=config.data.grid_r)

    grid = make_grid(output_image)
    # Scale to 0..255 (adding 0.5 so the uint8 cast rounds to nearest), clamp,
    # then reorder axes (1, 2, 0) -- presumably CHW -> HWC for display.
    ndarr = (
        grid.mul(255)
        .add_(0.5)
        .clamp_(0, 255)
        .permute(1, 2, 0)
        .to("cpu", torch.uint8)
        .numpy()
    )

    print("生成完成")
    return ndarr, image_tmp
|
|
|
|
|
# Web UI: the two inputs (image, seed) are passed to `sepia`, and its two
# return values fill the two output image panels in order.
demo = gr.Interface(
    fn=sepia,
    inputs=[
        gr.Image(label="输入甲骨文图片", height=600, width=600),
        gr.Number(label="随机种子", value=61),
    ],
    outputs=[
        gr.Image(label="输出汉字图片", height=600, width=600),
        gr.Image(label="矫正后甲骨文图片", height=600, width=600),
    ],
    title=title_markdown,
)

# Start the server as soon as the module is executed; share=True asks Gradio
# to create a public share link in addition to the local address.
demo.launch(share=True)
|
|