File size: 1,130 Bytes
0298ad2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import argparse
from pathlib import Path
import torch
import torch.distributed.checkpoint as DCP
from transformers import AutoModelForCausalLM
import fla # noqa
from torchtitan.tools.logging import init_logger, logger
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: Path):
    """Convert a Hugging Face causal-LM checkpoint into DCP format.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` and its
    state dict is saved with ``torch.distributed.checkpoint`` under the single
    top-level key ``"model"``.

    Args:
        model: Identifier handed to ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: Output directory for the DCP checkpoint. A plain string
            is accepted as well and is converted to a ``Path``. The directory
            (including missing parents) is created if it does not exist.
    """
    # Coerce first: the original annotated this parameter as ``str`` yet called
    # ``.mkdir`` on it, so a string argument crashed only after the model load.
    checkpoint = Path(checkpoint)

    logger.info(f"Loading model from {model}")
    # Use a separate name instead of rebinding the ``model`` string parameter.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
if __name__ == "__main__":
    # Script entry point: set up logging, read the two required CLI options,
    # and run the conversion.
    init_logger()

    cli = argparse.ArgumentParser(
        description="Convert huggingface-style model weights to DCP format."
    )
    # Both options are mandatory; --checkpoint is parsed directly into a Path.
    for flag, arg_type in (("--model", str), ("--checkpoint", Path)):
        cli.add_argument(flag, type=arg_type, required=True)

    options = cli.parse_args()
    convert_hf_weights(options.model, options.checkpoint)
|