File size: 3,374 Bytes
21ac790
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import argparse
import logging
import os

import torch

from hf_molmo.config_molmo import MolmoConfig
from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
from hf_molmo.modelling_molmo import MOLMoForCausalLM
from hf_molmo.preprocessing_molmo import MolmoProcessor
from olmo import ModelConfig
from olmo.mm_data.data_utils import build_tokenizer

# Module-level logger used for progress messages during conversion.
logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str, output_dir: str):
    # save config as HF config

    logger.info(f"Loading checkpoint from {checkpoint_dir}")

    config_path = os.path.join(checkpoint_dir, "config.yaml")
    model_config = ModelConfig.load(config_path, key="model")
    config_kwargs = model_config.asdict()
    config_kwargs["use_cache"] = True
    config_kwargs["vit_load_path"] = None
    config_kwargs["llm_load_path"] = None
    config = MolmoConfig(
        vocab_size=model_config.vocab_size,
        embedding_size=model_config.embedding_size,
        hidden_size=model_config.d_model,
        intermediate_size=model_config.mlp_hidden_size,
        num_hidden_layers=model_config.n_layers,
        num_attention_heads=model_config.n_heads,
        num_key_value_heads=model_config.n_kv_heads,
        max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
        initializer_range=model_config.initializer_range,
        use_cache=True,
        layer_norm_eps=model_config.layer_norm_eps,
        rope_theta=model_config.rope_theta,
        clip_qkv=model_config.clip_qkv,
        qkv_bias=model_config.qkv_bias,
        weight_tying=model_config.weight_tying,
        use_position_ids=True,
        tie_word_embeddings=False
    )

    logger.info(f"Saving HF-compatible config to {os.path.join(checkpoint_dir, 'config.json')}")
    config.save_pretrained(output_dir)

    preprocessor = MolmoProcessor(
        MolmoImageProcessor(
            max_crops=model_config.max_crops
        ),  # FIXME now just assumes everything if fixed
        build_tokenizer(model_config.tokenizer.identifier.split("m:")[1]).tokenizer
    )
    preprocessor.save_pretrained(output_dir)


def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
    """Re-save ``model.pt`` as ``pytorch_model.bin`` with HF-prefixed parameter keys.

    Every key of the loaded mapping is prefixed with
    ``MOLMoForCausalLM.base_model_prefix`` followed by a dot.

    Args:
        checkpoint_dir: Directory containing ``model.pt`` (a mapping loaded with ``torch.load``).
        output_dir: Destination directory for ``pytorch_model.bin``.
        ignore_olmo_compatibility: Currently unused; kept so existing callers keep working.
    """
    # For device_map = "auto", etc. the models are loaded in a way that start_prefix is not computed correctly.
    # So, we explicitly store the model with the expected prefix.
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(output_dir, "pytorch_model.bin")

    # Load onto CPU so conversion works on hosts lacking the device(s) the checkpoint
    # was saved from, and so the re-saved tensors are not bound to those devices.
    state_dict = torch.load(old_model_path, map_location="cpu")
    prefix = MOLMoForCausalLM.base_model_prefix
    new_state_dict = {f"{prefix}.{key}": val for key, val in state_dict.items()}
    torch.save(new_state_dict, new_model_path)


def convert_checkpoint(checkpoint_dir: str, output_dir: str):
    """Convert a training checkpoint directory into HF-loadable files.

    Ensures ``output_dir`` exists, then writes the config/processor files
    followed by the re-keyed weights.
    """
    target = output_dir
    os.makedirs(target, exist_ok=True)
    for step in (write_config, write_model):
        step(checkpoint_dir, target)


def main():
    """Command-line entry point: ``<script> CHECKPOINT_DIR OUTPUT_DIR``."""
    parser = argparse.ArgumentParser(
        description="Writes a Hugging Face config.json, processor files, and pytorch_model.bin "
                    "for a training checkpoint into OUTPUT_DIR, making it easier to load "
                    "weights as HF models."
    )
    parser.add_argument("checkpoint_dir", help="Directory containing config.yaml and model.pt.")
    parser.add_argument(
        "output_dir",
        help="Directory to write the HF-compatible files to (created if missing).",
    )
    args = parser.parse_args()
    # With no handler configured, Python drops INFO records, hiding the progress messages.
    logging.basicConfig(level=logging.INFO)
    convert_checkpoint(args.checkpoint_dir, args.output_dir)


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()