Spaces:
Running
on
Zero
Running
on
Zero
import argparse
import os
import re

from safetensors import safe_open
from safetensors.torch import save_file
# Matches the indexed segment of a multi-head key, e.g. "lora_ups.3.".
_LORA_UPS_SEGMENT = re.compile(r"lora_ups\.(\d+)\.")


def _select_lora_up(tensor_dict, lora_up):
    """Keep a single ``lora_ups.<i>.`` head and rename it to ``lora_up.``.

    Args:
        tensor_dict: Mapping of tensor name to tensor. Values are passed
            through untouched; only the keys are inspected.
        lora_up: 1-based number of the head to keep (head ``lora_up - 1``
            in the key names).

    Returns:
        A new dict containing every key without ``lora_ups`` unchanged, plus
        the tensors of the selected head with ``lora_ups.<i>.`` replaced by
        ``lora_up.``. Tensors of all other heads are dropped.

    Raises:
        ValueError: If a key mentions ``lora_ups`` without a numeric index
            segment, or if ``lora_ups`` keys exist but none has the
            requested index.
    """
    wanted_index = lora_up - 1
    selected = {}
    saw_lora_ups = False
    kept_any = False

    for key, tensor in tensor_dict.items():
        if 'lora_ups' not in key:
            selected[key] = tensor
            continue

        # Locate the index by pattern rather than by a fixed dot position, so
        # module prefixes that themselves contain dots are handled correctly.
        match = _LORA_UPS_SEGMENT.search(key)
        if match is None:
            raise ValueError(
                f"Cannot find a 'lora_ups.<index>.' segment in key {key!r}."
            )
        saw_lora_ups = True

        if int(match.group(1)) != wanted_index:
            continue

        # Replace exactly the matched text so the rename cannot silently miss.
        new_key = key[:match.start()] + 'lora_up.' + key[match.end():]
        selected[new_key] = tensor
        kept_any = True

    # Without this check an out-of-range selection would drop every up
    # tensor and still write a file.
    if saw_lora_ups and not kept_any:
        raise ValueError(
            f"--lora_up {lora_up} matched no 'lora_ups' tensors in the input."
        )
    return selected


def main():
    """Extract one ``lora_ups`` head from a safetensors file and save it.

    Reads ``--asylora_path``, keeps only the head chosen by the 1-based
    ``--lora_up`` value (renamed to ``lora_up``), copies every other tensor
    as-is, and writes the result to ``--output_path``. Invalid selections
    are reported through argparse (exit status 2) before anything is written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--asylora_path', type=str, required=True, help="Path to the input asylora file.")
    parser.add_argument('--output_path', type=str, required=True, help="Path to save the modified safetensors file.")
    parser.add_argument('--lora_up', type=int, required=True, help="The target lora_up value.")
    args = parser.parse_args()

    with safe_open(args.asylora_path, framework="pt") as f:
        tensor_dict = {key: f.get_tensor(key) for key in f.keys()}

    try:
        modified_dict = _select_lora_up(tensor_dict, args.lora_up)
    except ValueError as err:
        parser.error(str(err))

    # dirname() is '' for a bare filename; makedirs('') would raise, so only
    # create a directory when the output path actually names one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    save_file(modified_dict, args.output_path)
# Run the conversion only when this file is executed as a script.
if __name__ == "__main__":
    main()