Spaces:
Runtime error
Runtime error
ChongMou
commited on
Commit
·
4f29a2b
1
Parent(s):
0cf589a
Update demo/model.py
Browse files- demo/model.py +7 -8
demo/model.py
CHANGED
|
@@ -104,8 +104,7 @@ class Model_all:
|
|
| 104 |
self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
|
| 105 |
self.config.model.params.cond_stage_config.params.device = device
|
| 106 |
self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
|
| 107 |
-
self.
|
| 108 |
-
self.current_base_sketch = 'sd-v1-4.ckpt'
|
| 109 |
self.sampler = PLMSSampler(self.base_model)
|
| 110 |
|
| 111 |
# sketch part
|
|
@@ -144,7 +143,7 @@ class Model_all:
|
|
| 144 |
|
| 145 |
@torch.no_grad()
|
| 146 |
def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 147 |
-
if self.
|
| 148 |
ckpt = os.path.join("models", base_model)
|
| 149 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 150 |
if "state_dict" in pl_sd:
|
|
@@ -152,7 +151,7 @@ class Model_all:
|
|
| 152 |
else:
|
| 153 |
sd = pl_sd
|
| 154 |
self.base_model.load_state_dict(sd, strict=False)
|
| 155 |
-
self.
|
| 156 |
# del sd
|
| 157 |
# del pl_sd
|
| 158 |
con_strength = int((1-con_strength)*50)
|
|
@@ -218,7 +217,7 @@ class Model_all:
|
|
| 218 |
|
| 219 |
@torch.no_grad()
|
| 220 |
def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 221 |
-
if self.
|
| 222 |
ckpt = os.path.join("models", base_model)
|
| 223 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 224 |
if "state_dict" in pl_sd:
|
|
@@ -226,7 +225,7 @@ class Model_all:
|
|
| 226 |
else:
|
| 227 |
sd = pl_sd
|
| 228 |
self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
|
| 229 |
-
self.
|
| 230 |
con_strength = int((1-con_strength)*50)
|
| 231 |
if fix_sample == 'True':
|
| 232 |
seed_everything(42)
|
|
@@ -288,7 +287,7 @@ class Model_all:
|
|
| 288 |
|
| 289 |
@torch.no_grad()
|
| 290 |
def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 291 |
-
if self.
|
| 292 |
ckpt = os.path.join("models", base_model)
|
| 293 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 294 |
if "state_dict" in pl_sd:
|
|
@@ -296,7 +295,7 @@ class Model_all:
|
|
| 296 |
else:
|
| 297 |
sd = pl_sd
|
| 298 |
self.base_model.load_state_dict(sd, strict=False)
|
| 299 |
-
self.
|
| 300 |
con_strength = int((1-con_strength)*50)
|
| 301 |
if fix_sample == 'True':
|
| 302 |
seed_everything(42)
|
|
|
|
| 104 |
self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
|
| 105 |
self.config.model.params.cond_stage_config.params.device = device
|
| 106 |
self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
|
| 107 |
+
self.current_base = 'sd-v1-4.ckpt'
|
|
|
|
| 108 |
self.sampler = PLMSSampler(self.base_model)
|
| 109 |
|
| 110 |
# sketch part
|
|
|
|
| 143 |
|
| 144 |
@torch.no_grad()
|
| 145 |
def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 146 |
+
if self.current_base != base_model:
|
| 147 |
ckpt = os.path.join("models", base_model)
|
| 148 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 149 |
if "state_dict" in pl_sd:
|
|
|
|
| 151 |
else:
|
| 152 |
sd = pl_sd
|
| 153 |
self.base_model.load_state_dict(sd, strict=False)
|
| 154 |
+
self.current_base = base_model
|
| 155 |
# del sd
|
| 156 |
# del pl_sd
|
| 157 |
con_strength = int((1-con_strength)*50)
|
|
|
|
| 217 |
|
| 218 |
@torch.no_grad()
|
| 219 |
def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 220 |
+
if self.current_base != base_model:
|
| 221 |
ckpt = os.path.join("models", base_model)
|
| 222 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 223 |
if "state_dict" in pl_sd:
|
|
|
|
| 225 |
else:
|
| 226 |
sd = pl_sd
|
| 227 |
self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
|
| 228 |
+
self.current_base = base_model
|
| 229 |
con_strength = int((1-con_strength)*50)
|
| 230 |
if fix_sample == 'True':
|
| 231 |
seed_everything(42)
|
|
|
|
| 287 |
|
| 288 |
@torch.no_grad()
|
| 289 |
def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
|
| 290 |
+
if self.current_base != base_model:
|
| 291 |
ckpt = os.path.join("models", base_model)
|
| 292 |
pl_sd = torch.load(ckpt, map_location="cpu")
|
| 293 |
if "state_dict" in pl_sd:
|
|
|
|
| 295 |
else:
|
| 296 |
sd = pl_sd
|
| 297 |
self.base_model.load_state_dict(sd, strict=False)
|
| 298 |
+
self.current_base = base_model
|
| 299 |
con_strength = int((1-con_strength)*50)
|
| 300 |
if fix_sample == 'True':
|
| 301 |
seed_everything(42)
|