Spaces: Runtime error
Update model.py
model.py CHANGED
@@ -13,6 +13,10 @@ import time
 from autoregressive.models.generate import generate
 from condition.midas.depth import MidasDetector
 
+from controlnet_aux import (
+    MidasDetector,
+)
+
 models = {
     "canny": "checkpoints/canny_MR.safetensors",
     "depth": "checkpoints/depth_MR.safetensors",
@@ -48,7 +52,8 @@ class Model:
         self.gpt_model_canny = self.load_gpt(condition_type='canny')
         self.gpt_model_depth = self.load_gpt(condition_type='depth')
         self.get_control_canny = CannyDetector()
-        self.get_control_depth = MidasDetector('cuda')
+        # self.get_control_depth = MidasDetector('cuda')
+        self.get_control_depth = MidasDetector.from_pretrained("lllyasviel/Annotators")
 
     def to(self, device):
         self.gpt_model_canny.to('cuda')
@@ -196,11 +201,18 @@ class Model:
         # self.get_control_depth.model.to(self.device)
         # self.vq_model.to(self.device)
         image_tensor = torch.from_numpy(np.array(image)).to(self.device)
-        condition_img = torch.from_numpy(
-            self.get_control_depth(image_tensor)).unsqueeze(0)
-        condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
-        condition_img = condition_img.to(self.device)
-        condition_img = 2 * (condition_img / 255 - 0.5)
+        # condition_img = torch.from_numpy(
+        #     self.get_control_depth(image_tensor)).unsqueeze(0)
+        # condition_img = condition_img.unsqueeze(0).repeat(2, 3, 1, 1)
+        # condition_img = condition_img.to(self.device)
+        # condition_img = 2 * (condition_img / 255 - 0.5)
+
+        control_image = self.get_control_depth(
+            image=image,
+            image_resolution=512,
+            detect_resolution=512,
+        )
+
         prompts = [prompt] * 2
         caption_embs, emb_masks = self.t5_model.get_text_embeddings(prompts)
 
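Note on the import hunk: the added `from controlnet_aux import MidasDetector` rebinds the name `MidasDetector`, shadowing the `condition.midas.depth.MidasDetector` imported one line above, so the project's own detector class becomes unreachable. If replacing it outright is the intent, the old import can simply be dropped; if both are needed, an alias keeps them distinct. A minimal sketch of the alias route (the name `AuxMidasDetector` is hypothetical, not from the commit):

# Sketch, not part of the commit: alias the controlnet_aux detector so it
# does not shadow the project's own condition.midas.depth.MidasDetector.
from condition.midas.depth import MidasDetector                 # project-local detector
from controlnet_aux import MidasDetector as AuxMidasDetector    # hub-backed detector

# Load the pretrained MiDaS annotator from the Hugging Face Hub,
# as the commit does for self.get_control_depth.
get_control_depth = AuxMidasDetector.from_pretrained("lllyasviel/Annotators")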
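In the last hunk, the tensor pipeline is commented out and replaced by the controlnet_aux detector, which returns a PIL image in `control_image`; nothing in the visible hunk rebuilds the `condition_img` tensor (batch of 2, 3 channels, scaled to [-1, 1]) that the removed lines produced. Two hedged observations, either of which could relate to the Space's "Runtime error" badge: controlnet_aux detectors take their input as the first positional argument (commonly named `input_image`), so the keyword call `image=image` may raise a TypeError on some versions; and if code later in the method still reads `condition_img`, it would now hit a NameError. A minimal bridge under those assumptions; `to_condition_tensor` is a hypothetical helper, not part of the commit:

import numpy as np
import torch
from PIL import Image

def to_condition_tensor(control_image: Image.Image, device: str = "cuda") -> torch.Tensor:
    # Rebuild what the removed lines produced: two copies to match the
    # prompts = [prompt] * 2 batch, 3 channels, rescaled from [0, 255] to [-1, 1].
    arr = np.array(control_image.convert("RGB"))         # (H, W, 3) uint8
    t = torch.from_numpy(arr).permute(2, 0, 1).float()   # (3, H, W)
    t = t.unsqueeze(0).repeat(2, 1, 1, 1).to(device)     # (2, 3, H, W)
    return 2 * (t / 255 - 0.5)

# Inside the method, the detector call and bridge could read (sketch):
#     control_image = self.get_control_depth(image, detect_resolution=512, image_resolution=512)
#     condition_img = to_condition_tensor(control_image, device=self.device)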