"""Extract the ``mlp1`` submodule weights from an InternVL chat checkpoint.

Usage::

    python <this_script> MODEL_PATH OUTPUT_DIR

Loads the model at MODEL_PATH with ``torch_dtype=torch.bfloat16``, takes its
``mlp1`` submodule, and saves that submodule's ``state_dict`` to
``OUTPUT_DIR/mlp_projector.pth``.
"""
import argparse
import os

import torch

from internvl.model.internvl_chat import InternVLChatModel


def _parse_args(argv=None):
    """Parse the two required positional CLI arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``model_path`` and ``output_path``.
    """
    parser = argparse.ArgumentParser(
        description="Save the mlp1 projector state_dict of an InternVL chat model."
    )
    # Positional arguments are required, so a ``default`` would never be used.
    parser.add_argument('model_path', type=str,
                        help="Path or identifier passed to from_pretrained.")
    parser.add_argument('output_path', type=str,
                        help="Directory that will receive mlp_projector.pth.")
    return parser.parse_args(argv)


def main(argv=None):
    """Load the model, extract ``mlp1`` and write its state_dict to disk.

    Args:
        argv: Optional argument list (defaults to the process command line).
    """
    args = _parse_args(argv)

    model = InternVLChatModel.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    # Keep only the ``mlp1`` submodule. The explicit cast guarantees the saved
    # tensors are bfloat16 regardless of how from_pretrained handled dtype.
    projector = model.mlp1.to(torch.bfloat16)
    ckpt = projector.state_dict()

    # torch.save does not create parent directories; make sure it exists.
    os.makedirs(args.output_path, exist_ok=True)
    output_path = os.path.join(args.output_path, 'mlp_projector.pth')
    torch.save(ckpt, output_path)
    print('finished')


if __name__ == '__main__':
    main()