app
Browse files- app.py +1 -1
- models/ddm.py +4 -1
app.py
CHANGED
@@ -14,7 +14,7 @@ import tempfile
|
|
14 |
|
15 |
# tempfile.tempdir = "/home/dachuang/gradio/tmp/"
|
16 |
|
17 |
-
|
18 |
|
19 |
title_markdown = ("""
|
20 |
娆㈣繋鏉ュ埌鐢查鏂囨枃瀛楁紨鍙樻ā鎷熷櫒
|
|
|
14 |
|
15 |
# tempfile.tempdir = "/home/dachuang/gradio/tmp/"
|
16 |
|
17 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
18 |
|
19 |
title_markdown = ("""
|
20 |
娆㈣繋鏉ュ埌鐢查鏂囨枃瀛楁紨鍙樻ā鎷熷櫒
|
models/ddm.py
CHANGED
@@ -124,7 +124,10 @@ class DenoisingDiffusion(object):
|
|
124 |
self.num_timesteps = betas.shape[0]
|
125 |
|
126 |
def load_ddm_ckpt(self, load_path, ema=False):
|
127 |
-
|
|
|
|
|
|
|
128 |
self.start_epoch = checkpoint['epoch']
|
129 |
self.step = checkpoint['step']
|
130 |
self.model.load_state_dict(checkpoint['state_dict'], strict=True)
|
|
|
124 |
self.num_timesteps = betas.shape[0]
|
125 |
|
126 |
def load_ddm_ckpt(self, load_path, ema=False):
|
127 |
+
if self.device == torch.device('cpu'):
|
128 |
+
checkpoint = utils.logging.load_checkpoint(load_path, self.device)
|
129 |
+
else:
|
130 |
+
checkpoint = utils.logging.load_checkpoint(load_path, None)
|
131 |
self.start_epoch = checkpoint['epoch']
|
132 |
self.step = checkpoint['step']
|
133 |
self.model.load_state_dict(checkpoint['state_dict'], strict=True)
|