|
|
|
|
|
|
|
import argparse |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.distributed.checkpoint as DCP |
|
from transformers import AutoModelForCausalLM |
|
|
|
import fla |
|
from torchtitan.tools.logging import init_logger, logger |
|
|
|
|
|
@torch.inference_mode()
def convert_hf_weights(model: str, checkpoint: "str | Path") -> None:
    """Convert Hugging Face-style model weights into a DCP checkpoint.

    The model is loaded with ``AutoModelForCausalLM.from_pretrained`` and its
    full ``state_dict()`` is written with ``torch.distributed.checkpoint``
    under the top-level key ``"model"``.

    Args:
        model: Identifier passed verbatim to
            ``AutoModelForCausalLM.from_pretrained``.
        checkpoint: Output directory for the DCP checkpoint, given either as
            a string or as a ``pathlib.Path``. It is created (including any
            missing parents) if it does not exist yet.
    """
    # The signature historically advertised ``str`` while the body required a
    # ``Path`` (``.mkdir``); normalising here makes both spellings work.
    checkpoint_dir = Path(checkpoint)

    logger.info(f"Loading model from {model}")
    # Keep the loaded module in its own name instead of rebinding the
    # ``model`` argument, which is a string identifier.
    hf_model = AutoModelForCausalLM.from_pretrained(model)
    state_dict = hf_model.state_dict()

    logger.info(f"Writing to DCP at '{checkpoint_dir}'")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    storage_writer = DCP.filesystem.FileSystemWriter(checkpoint_dir, thread_count=8)
    # The weights are nested under "model"; readers of this checkpoint must
    # use the same key.
    DCP.save({"model": state_dict}, storage_writer=storage_writer)
|
|
|
|
|
if __name__ == "__main__":
    # Initialise the logger first so the conversion progress messages are
    # emitted through it.
    init_logger()

    cli = argparse.ArgumentParser(
        description="Convert huggingface-style model weights to DCP format."
    )
    # --model is handed to AutoModelForCausalLM.from_pretrained as-is;
    # --checkpoint is parsed into a Path naming the DCP output directory.
    cli.add_argument("--model", type=str, required=True)
    cli.add_argument("--checkpoint", type=Path, required=True)
    options = cli.parse_args()

    convert_hf_weights(options.model, options.checkpoint)
|
|