Spaces:
Runtime error
Runtime error
Use mixed precision
Browse files
model.py
CHANGED
|
@@ -57,9 +57,13 @@ class Model:
|
|
| 57 |
if base_model_id == self.base_model_id and task_name == self.task_name:
|
| 58 |
return self.pipe
|
| 59 |
model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 60 |
-
controlnet = ControlNetModel.from_pretrained(model_id
|
|
|
|
| 61 |
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 62 |
-
base_model_id,
|
|
|
|
|
|
|
|
|
|
| 63 |
pipe.scheduler = UniPCMultistepScheduler.from_config(
|
| 64 |
pipe.scheduler.config)
|
| 65 |
pipe.enable_xformers_memory_efficient_attention()
|
|
@@ -89,7 +93,9 @@ class Model:
|
|
| 89 |
torch.cuda.empty_cache()
|
| 90 |
gc.collect()
|
| 91 |
model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 92 |
-
controlnet = ControlNetModel.from_pretrained(model_id
|
|
|
|
|
|
|
| 93 |
torch.cuda.empty_cache()
|
| 94 |
gc.collect()
|
| 95 |
self.pipe.controlnet = controlnet
|
|
@@ -102,6 +108,7 @@ class Model:
|
|
| 102 |
prompt = f'{prompt}, {additional_prompt}'
|
| 103 |
return prompt
|
| 104 |
|
|
|
|
| 105 |
def run_pipe(
|
| 106 |
self,
|
| 107 |
prompt: str,
|
|
|
|
| 57 |
if base_model_id == self.base_model_id and task_name == self.task_name:
|
| 58 |
return self.pipe
|
| 59 |
model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 60 |
+
controlnet = ControlNetModel.from_pretrained(model_id,
|
| 61 |
+
torch_dtype=torch.float16)
|
| 62 |
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 63 |
+
base_model_id,
|
| 64 |
+
safety_checker=None,
|
| 65 |
+
controlnet=controlnet,
|
| 66 |
+
torch_dtype=torch.float16)
|
| 67 |
pipe.scheduler = UniPCMultistepScheduler.from_config(
|
| 68 |
pipe.scheduler.config)
|
| 69 |
pipe.enable_xformers_memory_efficient_attention()
|
|
|
|
| 93 |
torch.cuda.empty_cache()
|
| 94 |
gc.collect()
|
| 95 |
model_id = CONTROLNET_MODEL_IDS[task_name]
|
| 96 |
+
controlnet = ControlNetModel.from_pretrained(model_id,
|
| 97 |
+
torch_dtype=torch.float16)
|
| 98 |
+
controlnet.to(self.device)
|
| 99 |
torch.cuda.empty_cache()
|
| 100 |
gc.collect()
|
| 101 |
self.pipe.controlnet = controlnet
|
|
|
|
| 108 |
prompt = f'{prompt}, {additional_prompt}'
|
| 109 |
return prompt
|
| 110 |
|
| 111 |
+
@torch.autocast('cuda')
|
| 112 |
def run_pipe(
|
| 113 |
self,
|
| 114 |
prompt: str,
|