HaisuGuan committed
Commit 4dac99b · 1 Parent(s): d52ead0

Create app.py

Files changed (1)
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
import argparse
import os
import yaml
import torch
import torch.backends.cudnn as cudnn
import numpy as np
from data_web import web_input
from models import DenoisingDiffusion, DiffusiveRestoration
import utils
import PIL.Image as Image
from torchvision.utils import make_grid
import gradio as gr
import tempfile

# tempfile.tempdir = "/home/dachuang/gradio/tmp/"

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Page title shown in the Gradio UI (Chinese). English translation:
# "Welcome to the oracle bone script evolution simulator.
#  Just upload an image of an oracle bone character and you can see how it
#  evolves into a modern Chinese character under different random seeds."
title_markdown = ("""
欢迎来到甲骨文文字演变模拟器
你只需要输入一张甲骨文图片,就可以看到它在不同随机种子下演变到汉字的结果。
""")


def config_get():
    parser = argparse.ArgumentParser()
    # Path to the YAML configuration file
    parser.add_argument("--config", default='config_web.yml', type=str, required=False,
                        help="Path to the config file")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    return new_config


def dict2namespace(config):
    # Recursively convert a (possibly nested) dict into an argparse.Namespace
    # so config values can be accessed as attributes, e.g. config.data.grid_r.
    namespace = argparse.Namespace()
    for key, value in config.items():
        if isinstance(value, dict):
            new_value = dict2namespace(value)
        else:
            new_value = value
        setattr(namespace, key, new_value)
    return namespace


config = config_get()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("=> using device: {}".format(device))
config.device = device
# torch.backends.cudnn.benchmark = True
print("=> creating diffusion model")
diffusion = DenoisingDiffusion(config, test=True)
model = DiffusiveRestoration(diffusion, config)


# Offline test path (kept for reference): load a local image, restore it once,
# and save the result to disk.
# seed = 61
# torch.manual_seed(seed)
# np.random.seed(seed)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(seed)
# image_path = "/home/dachuang/hsguan/JGWNET/壴30.png"
# img = Image.open(image_path).convert('RGB')
# image_web = np.array(img)
# image = web_input(image_web)

# output_image = model.web_restore(image, r=config.data.grid_r)
# utils.logging.save_image(output_image, 'web/0.png')
# grid = make_grid(output_image)
# ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()

# im = Image.fromarray(ndarr)
# im.save('web_output.png', format=None)


def sepia(image_web, seed):
    # gr.Number passes the seed as a float; cast it so the seeding calls accept it.
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Preprocess the uploaded image and run the diffusion-based restoration.
    image = web_input(image_web)
    output_image = model.web_restore(image, r=config.data.grid_r)
    # Convert the output tensor to an HxWxC uint8 array for Gradio to display.
    grid = make_grid(output_image)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    # torch.cuda.empty_cache()
    return ndarr


# Labels (Chinese): input image = "input oracle bone script image",
# number = "random seed", output image = "output Chinese character image".
demo = gr.Interface(sepia,
                    inputs=[gr.Image(label="输入甲骨文图片", height=600, width=600),
                            gr.Number(label="随机种子")],
                    outputs=gr.Image(label="输出汉字图片", height=600, width=600),
                    title=title_markdown)
demo.queue().launch(
    server_name="127.0.0.1",
    server_port=7681,
    share=True
)