HaisuGuan commited on
Commit
f8e5b66
·
1 Parent(s): 4dac99b

Upload 2 files

Browse files
Files changed (2) hide show
  1. config_web.yml +53 -0
  2. data_web.py +15 -0
config_web.yml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ image_size: 64
3
+ channels: 3
4
+ num_workers: 2
5
+ train_data_dir: '/data/JGW/hsguan/JGWNET/train/' # path to directory of train data
6
+ test_data_dir: '/data/JGW/hsguan/JGWNET/selected/' # path to directory of test data
7
+ #test_data_dir: '/home/dachuang/hsguan/JGWNET/data2/test2/'
8
+ #test_save_dir: '/home/dachuang/hsguan/JGWNET/selected' # path to directory of saving restored data
9
+ test_save_dir: '/home/dachuang/hsguan/JGWNET/result90'
10
+ val_save_dir: '/data/JGW/hsguan/validation/'
11
+ grid_r: 16
12
+ conditional: True
13
+ tensorboard: '/home/dachuang/hsguan/JGWNET/90logs'
14
+
15
+ model:
16
+ in_channels: 3
17
+ out_ch: 3
18
+ ch: 128
19
+ ch_mult: [1, 2, 3, 4]
20
+ num_res_blocks: 2
21
+ attn_resolutions: [16, ]
22
+ dropout: 0.0
23
+ ema_rate: 0.999
24
+ ema: True
25
+ resamp_with_conv: True
26
+
27
+ diffusion:
28
+ beta_schedule: linear
29
+ beta_start: 0.0001
30
+ beta_end: 0.02
31
+ num_diffusion_timesteps: 1000
32
+
33
+ training:
34
+ patch_n: 8
35
+ batch_size: 8
36
+ n_epochs: 1000
37
+ n_iters: 2000000
38
+ snapshot_freq: 20 # model save frequency
39
+ validation_freq: 10000
40
+ resume: './diffusion_model_new' # path to pretrained model
41
+ seed: 61 # random seed
42
+
43
+ sampling:
44
+ batch_size: 1
45
+ last_only: True
46
+ sampling_timesteps: 100
47
+
48
+ optim:
49
+ weight_decay: 0.01
50
+ optimizer: "Adam"
51
+ lr: 0.0001
52
+ amsgrad: False
53
+ eps: 0.00000001
data_web.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torchvision
3
+ import PIL
4
+
5
+
6
+ def web_input(image):
7
+ image_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
8
+ input_img = PIL.Image.fromarray(image)
9
+ input_img = input_img.resize((100, 100), PIL.Image.LANCZOS)
10
+ wd_new, ht_new = input_img.size
11
+ wd_new = int(16 * np.ceil(wd_new / 16.0))
12
+ ht_new = int(16 * np.ceil(ht_new / 16.0))
13
+ input_img = input_img.resize((wd_new, ht_new), PIL.Image.LANCZOS)
14
+ return image_transforms(input_img).unsqueeze(0)
15
+