Build error
Update app.py
app.py CHANGED
@@ -106,6 +106,24 @@ models_b = WurstCoreB.Models(
 )
 models_b.generator.bfloat16().eval().requires_grad_(False)
 
+if low_vram:
+    # Off-load old generator (which is not used in models_rbm)
+    models.generator.to("cpu")
+    torch.cuda.empty_cache()
+
+generator_rbm = StageCRBM()
+for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+    set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
+generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+generator_rbm = core.load_model(generator_rbm, 'generator')
+
+models_rbm = core.Models(
+    effnet=models.effnet, previewer=models.previewer,
+    generator=generator_rbm, generator_ema=models.generator_ema,
+    tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
+)
+models_rbm.generator.eval().requires_grad_(False)
+
 def infer(ref_style_file, style_description, caption):
     global models_rbm, models_b, device
     if low_vram:
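
For context, the added block follows a common low-VRAM pattern in PyTorch: move modules the pipeline no longer needs to the CPU, release the CUDA cache, load the replacement generator's weights tensor by tensor while it still lives on the CPU, and only then cast it and move it to the GPU. The sketch below is a minimal, self-contained illustration of that pattern, not the Space's actual code: the plain nn.Linear modules and the in-memory state dict are hypothetical stand-ins for models.generator, StageCRBM(), and the load_or_fail(...) checkpoint, while set_module_tensor_to_device is the same accelerate helper the diff calls.

import torch
import torch.nn as nn
from accelerate.utils import set_module_tensor_to_device

device = "cuda" if torch.cuda.is_available() else "cpu"

old_generator = nn.Linear(16, 16).to(device)                      # stand-in for models.generator
new_generator = nn.Linear(16, 16)                                  # stand-in for StageCRBM()
state = {"weight": torch.randn(16, 16), "bias": torch.zeros(16)}   # stand-in for the checkpoint

# 1) Off-load the module that is no longer needed and release cached GPU blocks.
old_generator.to("cpu")
torch.cuda.empty_cache()

# 2) Copy each checkpoint tensor onto the new module while it is still on the CPU.
for param_name, param in state.items():
    set_module_tensor_to_device(new_generator, param_name, "cpu", value=param)

# 3) Cast, move to the GPU, and freeze for inference.
new_generator = new_generator.to(torch.bfloat16).to(device)
new_generator.eval().requires_grad_(False)

Keeping the weights on the CPU until the final cast avoids holding two full copies of the generator in GPU memory at once, which appears to be the point of the low_vram branch.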