Spaces:
Build error
Build error
import os, sys | |
sys.path.insert(0, os.getcwd()) | |
import argparse | |
def get_args(argv=None):
    """Parse the command-line options of the LyCORIS merge script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        An ``argparse.Namespace`` with the attributes ``base_model``,
        ``lycoris_model``, ``output_name``, ``is_v2``, ``is_sdxl``,
        ``device``, ``dtype`` and ``weight``.
    """
    parser = argparse.ArgumentParser()

    # Positional arguments. They are declared without ``nargs``, so argparse
    # requires all three; the ``default`` values are kept only for parity
    # with the original declaration.
    parser.add_argument(
        "base_model", help="The model you want to merge with loha", default="", type=str
    )
    parser.add_argument(
        "lycoris_model",
        help="the lyco model you want to merge into sd model",
        default="",
        type=str,
    )
    parser.add_argument(
        "output_name", help="the output model", default="./out.pt", type=str
    )

    # Flags describing the architecture of the base checkpoint.
    parser.add_argument(
        "--is_v2",
        help="Your base model is sd v2 or not",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--is_sdxl",
        help="Your base/db model is sdxl or not",
        default=False,
        action="store_true",
    )

    # Merge / save settings.
    parser.add_argument(
        "--device",
        help="Which device you want to use to merge the weight",
        default="cpu",
        type=str,
    )
    parser.add_argument("--dtype", help="dtype to save", default="float", type=str)
    parser.add_argument(
        # The default is a real float so it already has the declared type.
        "--weight", help="weight for the lyco model to merge", default=1.0, type=float
    )
    return parser.parse_args(argv)
# Parse the command line once at module level. ``ARGS`` and ``args`` are two
# names bound to the same namespace object; main() reads both spellings.
ARGS = get_args()
args = ARGS
from lycoris.utils import merge | |
from lycoris.kohya.model_utils import ( | |
load_models_from_stable_diffusion_checkpoint, | |
save_stable_diffusion_checkpoint, | |
load_file, | |
) | |
from lycoris.kohya.sdxl_model_util import ( | |
load_models_from_sdxl_checkpoint, | |
save_stable_diffusion_checkpoint as save_sdxl_checkpoint, | |
) | |
import torch | |
def _resolve_dtype(name):
    """Translate a ``--dtype`` string into the matching ``torch.dtype``.

    Accepts the full names ``float``, ``float16``, ``float32``, ``float64``,
    ``bfloat`` and ``bfloat16`` as well as the short ``fp*`` / ``bf*``
    spellings (``fp16``, ``bf16``, ...).

    Raises:
        ValueError: If ``name`` does not correspond to a supported dtype.
    """
    supported = {
        "float": torch.float,
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
        "bfloat": torch.bfloat16,
        "bfloat16": torch.bfloat16,
    }
    # Expand the short prefixes only. A blind ``replace("bf", "bfloat")``
    # would also rewrite the "bf" inside "bfloat16" and make it unresolvable.
    if name.startswith("fp"):
        key = "float" + name[2:]
    elif name.startswith("bf") and not name.startswith("bfloat"):
        key = "bfloat" + name[2:]
    else:
        key = name

    dtype = supported.get(key)
    if dtype is None:
        # Report the string the user passed, not the failed lookup result.
        raise ValueError(f'Cannot Find the dtype "{name}"')
    return dtype


def main():
    """Merge the LyCORIS weights into the base checkpoint and save the result.

    Reads its settings from the module-level ``ARGS`` namespace: loads the
    base checkpoint (SDXL or regular Stable Diffusion), loads the LyCORIS
    state dict, merges it with ``merge`` and writes the merged checkpoint to
    ``ARGS.output_name`` in the requested dtype.

    Raises:
        ValueError: If ``ARGS.dtype`` is not a supported dtype name.
    """
    # Validate the dtype before loading anything, so a typo fails immediately
    # instead of after the checkpoints have been read.
    dtype = _resolve_dtype(ARGS.dtype)

    if ARGS.is_sdxl:
        base = load_models_from_sdxl_checkpoint(
            None, ARGS.base_model, map_location=ARGS.device
        )
    else:
        base = load_models_from_stable_diffusion_checkpoint(
            ARGS.is_v2, ARGS.base_model
        )

    if ARGS.lycoris_model.rsplit(".", 1)[-1] == "safetensors":
        lyco = load_file(ARGS.lycoris_model)
    else:
        # NOTE(security): torch.load unpickles the file; only use it on
        # LyCORIS files from a trusted source.
        # map_location lets a file saved on another device (e.g. CUDA) be
        # loaded on the device chosen for the merge.
        lyco = torch.load(ARGS.lycoris_model, map_location=ARGS.device)

    # The loaders return tuples; the indices below follow the original script.
    if ARGS.is_sdxl:
        base_tes = [base[0], base[1]]
        base_unet = base[3]
    else:
        base_tes = [base[0]]
        base_unet = base[2]

    merge(base_tes, base_unet, lyco, ARGS.weight, ARGS.device)

    if ARGS.is_sdxl:
        save_sdxl_checkpoint(
            ARGS.output_name,
            base[0].cpu(),
            base[1].cpu(),
            base[3].cpu(),
            0,
            0,
            None,
            base[2],  # presumably the VAE - TODO confirm against the loader
            getattr(base[1], "logit_scale", None),
            dtype,
        )
    else:
        save_stable_diffusion_checkpoint(
            ARGS.is_v2,
            ARGS.output_name,
            base[0].cpu(),
            base[2].cpu(),
            None,
            0,
            0,
            dtype,
            base[1],  # presumably the VAE - TODO confirm against the loader
        )


if __name__ == "__main__":
    main()