Spaces:
Runtime error
Runtime error
'''
Converts a Hugging Face transformers causal language model to a format
compatible with FlexGen: every parameter is saved as its own NumPy file
under models/<model name>-np.
'''
| import argparse | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
def _wide_help_formatter(prog):
    """Build the default help formatter with a wider help-text column (54)."""
    return argparse.HelpFormatter(prog, max_help_position=54)


# Command-line interface: a single optional positional argument, MODEL.
parser = argparse.ArgumentParser(formatter_class=_wide_help_formatter)
parser.add_argument(
    'MODEL',
    nargs='?',
    default=None,
    type=str,
    help="Path to the input model.",
)
# Parsed at module level; the script body below reads args.MODEL.
args = parser.parse_args()
def _skip_reset_parameters(self):
    """No-op stand-in installed in place of ``reset_parameters``."""
    return None


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` with a no-op. The original methods are stored in
    the module globals ``torch_linear_init_backup`` and
    ``torch_layer_norm_init_backup`` so that ``restore_torch_init`` can put
    them back.

    Calling this function more than once is safe: the originals are only
    saved while the no-op is not yet installed, so a repeated call never
    overwrites the backups with the no-op itself.
    """
    global torch_linear_init_backup
    global torch_layer_norm_init_backup

    # Only back up when the real method is still in place; otherwise a second
    # call would save the stub and make the original unrecoverable.
    if torch.nn.Linear.reset_parameters is not _skip_reset_parameters:
        torch_linear_init_backup = torch.nn.Linear.reset_parameters
        torch.nn.Linear.reset_parameters = _skip_reset_parameters

    if torch.nn.LayerNorm.reset_parameters is not _skip_reset_parameters:
        torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
        torch.nn.LayerNorm.reset_parameters = _skip_reset_parameters
def restore_torch_init():
    """
    Rollback the change made by disable_torch_init.

    Reinstalls the ``reset_parameters`` methods saved in the module globals
    ``torch_linear_init_backup`` and ``torch_layer_norm_init_backup``.
    If no backup has been saved (``disable_torch_init`` was never called),
    this is a no-op instead of raising ``NameError``.
    """
    # Look the backups up dynamically: they only exist as module globals after
    # disable_torch_init has run.
    saved_linear = globals().get("torch_linear_init_backup")
    saved_layer_norm = globals().get("torch_layer_norm_init_backup")

    if saved_linear is not None:
        torch.nn.Linear.reset_parameters = saved_linear
    if saved_layer_norm is not None:
        torch.nn.LayerNorm.reset_parameters = saved_layer_norm
if __name__ == '__main__':
    # MODEL is declared optional (nargs='?', default None), so it may be
    # missing here. Report a usage error instead of letting Path(None) raise
    # a TypeError.
    if args.MODEL is None:
        parser.error("the MODEL argument is required (path to the input model)")

    path = Path(args.MODEL)
    model_name = path.name

    print(f"Loading {model_name}...")
    # disable_torch_init()
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # restore_torch_init()

    # NOTE(review): the tokenizer is loaded but never used below. It is kept
    # so the script still fails at this point when it cannot be loaded, as
    # before; consider removing it.
    tokenizer = AutoTokenizer.from_pretrained(path)

    out_folder = Path(f"models/{model_name}-np")
    # parents=True also creates a missing "models" directory, and
    # exist_ok=True makes reruns safe without a separate exists() check.
    out_folder.mkdir(parents=True, exist_ok=True)

    print(f"Saving the converted model to {out_folder}...")
    # Materialise the parameters into a list so tqdm knows the total count.
    for name, param in tqdm(list(model.model.named_parameters())):
        # Rename the final layer norm to "decoder.layer_norm" in the output
        # file name.
        name = name.replace("decoder.final_layer_norm", "decoder.layer_norm")
        param_path = out_folder / name
        # Saving through an open file object keeps the file name exactly equal
        # to the parameter name (np.save appends ".npy" only when given a path).
        with open(param_path, "wb") as f:
            np.save(f, param.cpu().detach().numpy())