# Provenance: Hugging Face upload by zaydzuhri (upload-large-folder tool), commit 0298ad2 (verified), 1.13 kB.
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import argparse
from pathlib import Path
import torch
import torch.distributed.checkpoint as DCP
from transformers import AutoModelForCausalLM
import fla # noqa
from torchtitan.tools.logging import init_logger, logger
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: "str | Path", *, thread_count: int = 8) -> None:
    """Convert a Hugging Face causal-LM checkpoint into a DCP checkpoint.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` and its
    full ``state_dict()`` is written with ``torch.distributed.checkpoint.save``
    under the top-level key ``"model"``. Runs under ``torch.inference_mode``,
    so no autograd state is recorded while loading/saving.

    Args:
        model: Model identifier or local path accepted by
            ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: Output directory for the DCP checkpoint. Either ``str`` or
            ``Path`` is accepted; the directory (and any missing parents) is
            created if it does not exist.
        thread_count: Number of writer threads passed to ``FileSystemWriter``
            (defaults to 8, the value previously hard-coded).
    """
    # Normalize so callers passing a plain string don't crash on ``.mkdir``
    # (the CLI passes a Path, but the annotation advertised ``str``).
    checkpoint = Path(checkpoint)

    logger.info(f"Loading model from {model}")
    # Keep ``model`` as the source id/path; bind the loaded object separately.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint}'")
    checkpoint.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(
        checkpoint, thread_count=thread_count
    )
    # Nest under "model" so the saved layout is {"model": {...}}; presumably
    # this matches the key the training-side loader reads — confirm.
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    init_logger()
    parser = argparse.ArgumentParser(description="Convert huggingface-style model weights to DCP format.")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Hugging Face model name or local path to load with AutoModelForCausalLM.",
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        required=True,
        help="Output directory for the DCP checkpoint (created if missing).",
    )
    args = parser.parse_args()
    convert_hf_weights(args.model, args.checkpoint)