add safetensors
Browse files
vae.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
# Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
| 4 |
|
| 5 |
import numpy as np
|
| 6 |
|
|
@@ -461,7 +462,7 @@ class AutoencoderKL(nn.Module):
|
|
| 461 |
self.init_from_ckpt(ckpt_path)
|
| 462 |
|
| 463 |
def init_from_ckpt(self, path):
|
| 464 |
-
sd =
|
| 465 |
msg = self.load_state_dict(sd, strict=False)
|
| 466 |
print("Loading pre-trained KL-VAE")
|
| 467 |
print("Missing keys:")
|
|
|
|
| 1 |
# Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
|
| 6 |
import numpy as np
|
| 7 |
|
|
|
|
| 462 |
self.init_from_ckpt(ckpt_path)
|
| 463 |
|
| 464 |
def init_from_ckpt(self, path):
|
| 465 |
+
sd = load_file(path)
|
| 466 |
msg = self.load_state_dict(sd, strict=False)
|
| 467 |
print("Loading pre-trained KL-VAE")
|
| 468 |
print("Missing keys:")
|