Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
import torch.backends.cudnn as cudnn
|
6 |
+
import numpy as np
|
7 |
+
from data_web import web_input
|
8 |
+
from models import DenoisingDiffusion, DiffusiveRestoration
|
9 |
+
import utils
|
10 |
+
import PIL.Image as Image
|
11 |
+
from torchvision.utils import make_grid
|
12 |
+
import gradio as gr
|
13 |
+
import tempfile
|
14 |
+
|
15 |
+
# tempfile.tempdir = "/home/dachuang/gradio/tmp/"
|
16 |
+
|
17 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
18 |
+
|
19 |
+
title_markdown = ("""
|
20 |
+
欢迎来到甲骨文文字演变模拟器
|
21 |
+
你只需要输入一张甲骨文图片,就可以看到它在不同随机种子下演变到汉字的结果。
|
22 |
+
""")
|
23 |
+
|
24 |
+
|
25 |
+
def config_get():
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
# 参数配置文件路径
|
28 |
+
parser.add_argument("--config", default='config_web.yml', type=str, required=False, help="Path to the config file")
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
with open(os.path.join(args.config), "r") as f:
|
32 |
+
config = yaml.safe_load(f)
|
33 |
+
new_config = dict2namespace(config)
|
34 |
+
|
35 |
+
return new_config
|
36 |
+
|
37 |
+
|
38 |
+
def dict2namespace(config):
|
39 |
+
namespace = argparse.Namespace()
|
40 |
+
for key, value in config.items():
|
41 |
+
if isinstance(value, dict):
|
42 |
+
new_value = dict2namespace(value)
|
43 |
+
else:
|
44 |
+
new_value = value
|
45 |
+
setattr(namespace, key, new_value)
|
46 |
+
return namespace
|
47 |
+
|
48 |
+
|
49 |
+
config = config_get()
|
50 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
51 |
+
print("=> using device: {}".format(device))
|
52 |
+
config.device = device
|
53 |
+
# torch.backends.cudnn.benchmark = True
|
54 |
+
print("=> creating diffusion model")
|
55 |
+
diffusion = DenoisingDiffusion(config, test=True)
|
56 |
+
model = DiffusiveRestoration(diffusion, config)
|
57 |
+
|
58 |
+
|
59 |
+
# 加载数据
|
60 |
+
# seed = 61
|
61 |
+
# torch.manual_seed(seed)
|
62 |
+
# np.random.seed(seed)
|
63 |
+
# if torch.cuda.is_available():
|
64 |
+
# torch.cuda.manual_seed_all(seed)
|
65 |
+
# image_path = "/home/dachuang/hsguan/JGWNET/壴30.png"
|
66 |
+
# img = Image.open(image_path).convert('RGB')
|
67 |
+
# image_web = np.array(img)
|
68 |
+
# image = web_input(image_web)
|
69 |
+
|
70 |
+
# output_image = model.web_restore(image, r=config.data.grid_r)
|
71 |
+
# utils.logging.save_image(output_image, 'web/0.png')
|
72 |
+
# grid = make_grid(output_image)
|
73 |
+
# ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
74 |
+
|
75 |
+
|
76 |
+
# im = Image.fromarray(ndarr)
|
77 |
+
# im.save('web_output.png', format=None)
|
78 |
+
|
79 |
+
def sepia(image_web, seed):
|
80 |
+
torch.manual_seed(seed)
|
81 |
+
np.random.seed(seed)
|
82 |
+
if torch.cuda.is_available():
|
83 |
+
torch.cuda.manual_seed_all(seed)
|
84 |
+
image = web_input(image_web)
|
85 |
+
output_image = model.web_restore(image, r=config.data.grid_r)
|
86 |
+
grid = make_grid(output_image)
|
87 |
+
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
|
88 |
+
# torch.cuda.empty_cache()
|
89 |
+
return ndarr
|
90 |
+
|
91 |
+
|
92 |
+
demo = gr.Interface(sepia,
|
93 |
+
inputs=[gr.Image(label="输入甲骨文图片", height=600, width=600), gr.Number(label="随机种子")],
|
94 |
+
outputs=gr.Image(label="输出汉字图片", height=600, width=600),
|
95 |
+
title=title_markdown)
|
96 |
+
demo.queue().launch(
|
97 |
+
server_name="127.0.0.1",
|
98 |
+
server_port=7681,
|
99 |
+
share=True
|
100 |
+
)
|