"""Extract one sub-model's weights from a checkpoint into a new checkpoint.

Usage:
    python <this script> INPATH OUTPATH [SUBMODEL]

SUBMODEL defaults to "cond_stage_model".
"""
import sys

import torch


def extract_submodel(state_dict, submodel="cond_stage_model"):
    """Return the entries of *state_dict* that belong to *submodel*.

    A key belongs to the sub-model when it starts with ``submodel + "."``.
    That prefix is removed from the key in the returned dict; values are
    passed through unchanged.

    Matching on the prefix *including the trailing dot* keeps siblings such
    as ``cond_stage_model_extra.*`` out of the result, and lets *submodel*
    itself contain dots (e.g. ``"model.diffusion_model"``).

    Args:
        state_dict: Mapping from dotted parameter names to values.
        submodel: Dotted name of the sub-model whose entries are wanted.

    Returns:
        A new dict with the matching entries, keys relative to *submodel*.
        Empty when nothing matches.
    """
    prefix = submodel + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("usage: {} INPATH OUTPATH [SUBMODEL]".format(sys.argv[0]))

    inpath = sys.argv[1]
    outpath = sys.argv[2]
    # The optional third argument selects which sub-model to extract.
    submodel = sys.argv[3] if len(sys.argv) > 3 else "cond_stage_model"

    print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only run this on checkpoints from a trusted source.
    sd = torch.load(inpath, map_location="cpu")

    # Filter on the requested sub-model rather than a hard-coded name, so the
    # command-line argument actually takes effect.
    new_sd = {"state_dict": extract_submodel(sd["state_dict"], submodel)}
    torch.save(new_sd, outpath)