| import argparse | |
| import os | |
| import torch | |
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line arguments of the conversion script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the script
            does when run from the command line.

    Returns:
        A namespace with two required string attributes:
        ``source_model`` (input checkpoint path) and ``output_model``
        (path the converted weights are written to).
    """
    # The first positional argument of ArgumentParser is ``prog``, not the
    # description; pass it by keyword so --help shows a normal usage line
    # followed by this text as the description.
    parser = argparse.ArgumentParser(
        description="Convert Swin Transformer to Detectron2")
    # Positional arguments are required, so a ``default`` would never be used.
    parser.add_argument("source_model", type=str,
                        help="Source model")
    parser.add_argument("output_model", type=str,
                        help="Output model")
    return parser.parse_args(argv)
def convert_weights(source_weights, prefix="backbone.bottom_up."):
    """Return a copy of ``source_weights`` with ``prefix`` prepended to every key.

    Values are carried over unchanged (same objects, no tensor copies), and
    the original key order is preserved.

    Args:
        source_weights: Mapping from parameter name to value, e.g. the
            ``"model"`` entry of a checkpoint.
        prefix: String prepended to each key. The default presumably places
            the weights under a Detectron2 backbone whose ``bottom_up``
            submodule is the Swin model — verify against the model config.

    Returns:
        A new ``dict`` with the renamed keys.
    """
    return {prefix + key: value for key, value in source_weights.items()}


def main():
    """Convert a Swin Transformer ``.pth`` checkpoint into Detectron2 key naming.

    Reads the paths from the command line, loads the checkpoint on CPU,
    takes its ``"model"`` entry, prefixes every key with
    ``backbone.bottom_up.``, and saves the resulting flat dict with
    ``torch.save`` to the output path.

    Raises:
        ValueError: If the source path does not end in ``.pth``.
        KeyError: If the loaded checkpoint has no ``"model"`` entry.
    """
    args = parse_args()
    if os.path.splitext(args.source_model)[-1] != ".pth":
        raise ValueError("You should save weights as pth file")

    # NOTE(security): torch.load unpickles the file and can execute
    # arbitrary code; only run this script on checkpoints you trust.
    checkpoint = torch.load(
        args.source_model, map_location=torch.device('cpu'))
    try:
        source_weights = checkpoint["model"]
    except KeyError as err:
        raise KeyError(
            f"checkpoint {args.source_model!r} has no 'model' entry"
        ) from err

    torch.save(convert_weights(source_weights), args.output_model)


if __name__ == "__main__":
    main()