fffiloni commited on
Commit
de57e3e
1 Parent(s): cc91cd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -106,6 +106,24 @@ models_b = WurstCoreB.Models(
106
  )
107
  models_b.generator.bfloat16().eval().requires_grad_(False)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def infer(ref_style_file, style_description, caption):
110
  global models_rbm, models_b, device
111
  if low_vram:
 
106
  )
107
  models_b.generator.bfloat16().eval().requires_grad_(False)
108
 
109
+ if low_vram:
110
+ # Off-load old generator (which is not used in models_rbm)
111
+ models.generator.to("cpu")
112
+ torch.cuda.empty_cache()
113
+
114
+ generator_rbm = StageCRBM()
115
+ for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
116
+ set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
117
+ generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
118
+ generator_rbm = core.load_model(generator_rbm, 'generator')
119
+
120
+ models_rbm = core.Models(
121
+ effnet=models.effnet, previewer=models.previewer,
122
+ generator=generator_rbm, generator_ema=models.generator_ema,
123
+ tokenizer=models.tokenizer, text_model=models.text_model, image_model=models.image_model
124
+ )
125
+ models_rbm.generator.eval().requires_grad_(False)
126
+
127
  def infer(ref_style_file, style_description, caption):
128
  global models_rbm, models_b, device
129
  if low_vram: