import argparse
import logging
import os

import torch

from hf_molmo.config_molmo import MolmoConfig
from hf_molmo.image_preprocessing_molmo import MolmoImageProcessor
from hf_molmo.modelling_molmo import MOLMoForCausalLM
from hf_molmo.preprocessing_molmo import MolmoProcessor
from olmo import ModelConfig
from olmo.mm_data.data_utils import build_tokenizer

logger = logging.getLogger(__name__)


def write_config(checkpoint_dir: str, output_dir: str):
    # Save the model config in HF format
    logger.info(f"Loading checkpoint from {checkpoint_dir}")
    config_path = os.path.join(checkpoint_dir, "config.yaml")
    model_config = ModelConfig.load(config_path, key="model")
    config_kwargs = model_config.asdict()
    config_kwargs["use_cache"] = True
    config_kwargs["vit_load_path"] = None
    config_kwargs["llm_load_path"] = None

    config = MolmoConfig(
        vocab_size=model_config.vocab_size,
        embedding_size=model_config.embedding_size,
        hidden_size=model_config.d_model,
        intermediate_size=model_config.mlp_hidden_size,
        num_hidden_layers=model_config.n_layers,
        num_attention_heads=model_config.n_heads,
        num_key_value_heads=model_config.n_kv_heads,
        max_position_embeddings=model_config.max_position_embeddings or model_config.max_sequence_length,
        initializer_range=model_config.initializer_range,
        use_cache=True,
        layer_norm_eps=model_config.layer_norm_eps,
        rope_theta=model_config.rope_theta,
        clip_qkv=model_config.clip_qkv,
        qkv_bias=model_config.qkv_bias,
        weight_tying=model_config.weight_tying,
        use_position_ids=True,
        tie_word_embeddings=False,
    )

    logger.info(f"Saving HF-compatible config to {os.path.join(output_dir, 'config.json')}")
    config.save_pretrained(output_dir)

    preprocessor = MolmoProcessor(
        MolmoImageProcessor(
            max_crops=model_config.max_crops
        ),  # FIXME currently just assumes everything is fixed
        build_tokenizer(model_config.tokenizer.identifier.split("m:")[1]).tokenizer
    )
    preprocessor.save_pretrained(output_dir)


def write_model(checkpoint_dir: str, output_dir: str, ignore_olmo_compatibility: bool = False):
    # For device_map="auto", etc., the models are loaded in a way that start_prefix is not
    # computed correctly, so we explicitly store the model with the expected prefix.
    old_model_path = os.path.join(checkpoint_dir, "model.pt")
    new_model_path = os.path.join(output_dir, "pytorch_model.bin")

    # Load on CPU so conversion also works on machines without a GPU.
    state_dict = torch.load(old_model_path, map_location="cpu")
    new_state_dict = {f"{MOLMoForCausalLM.base_model_prefix}.{key}": val for key, val in state_dict.items()}
    torch.save(new_state_dict, new_model_path)


def convert_checkpoint(checkpoint_dir: str, output_dir: str):
    os.makedirs(output_dir, exist_ok=True)
    write_config(checkpoint_dir, output_dir)
    write_model(checkpoint_dir, output_dir)


def main():
    parser = argparse.ArgumentParser(
        description="Adds a config.json to the checkpoint directory and creates pytorch_model.bin, "
                    "making it easier to load the weights as an HF model."
    )
    parser.add_argument("checkpoint_dir")
    parser.add_argument("output_dir")
    args = parser.parse_args()
    convert_checkpoint(args.checkpoint_dir, args.output_dir)


if __name__ == "__main__":
    main()
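
# Example usage -- a minimal sketch; the script name and paths below are
# hypothetical, and this assumes MOLMoForCausalLM and MolmoProcessor follow the
# standard transformers save/load conventions so that `from_pretrained` picks up
# the config.json and pytorch_model.bin written above:
#
#   python convert_checkpoint.py /ckpts/molmo/step10000 /ckpts/molmo-hf
#
#   from hf_molmo.modelling_molmo import MOLMoForCausalLM
#   from hf_molmo.preprocessing_molmo import MolmoProcessor
#   model = MOLMoForCausalLM.from_pretrained("/ckpts/molmo-hf")
#   processor = MolmoProcessor.from_pretrained("/ckpts/molmo-hf")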