#!/usr/bin/env python3 """ Mistral Model Transformer This script transforms Mistral-Small-3.1-24B-Base-2503 into a text-only model by: 1. Removing multimodality features 2. Removing the vision encoder 3. Changing the architecture from "mistral3" to "mistral" 4. Ensuring weight mapping structure matches Devstral-Small-2505 exactly Usage: python convert.py --input-model mistralai/Mistral-Small-3.1-24B-Base-2503 --output-path ./mistral-small-text-only --reference-model mistralai/Devstral-Small-2505 Note: This script requires significant disk space to download and process the full model. """ import argparse import json import os import shutil from pathlib import Path import logging from huggingface_hub import snapshot_download, hf_hub_download from safetensors.torch import load_file, save_file from transformers import AutoConfig, AutoModelForCausalLM logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) def parse_args(): parser = argparse.ArgumentParser(description="Transform Mistral model to text-only version") parser.add_argument( "--input-model", type=str, default="mistralai/Mistral-Small-3.1-24B-Base-2503", help="Path or HF repo id of the input model" ) parser.add_argument( "--output-path", type=str, required=True, help="Path to save the transformed model" ) parser.add_argument( "--cache-dir", type=str, default=None, help="Cache directory for downloading models" ) parser.add_argument( "--reference-model", type=str, default="mistralai/Devstral-Small-2505", help="Path or HF repo id of the reference model for weight mapping" ) return parser.parse_args() def transform_config(config_path, output_path, reference_config=None): """ Transform the model config by: 1. Changing model_type from "mistral3" to "mistral" 2. Removing vision_config 3. Removing multimodal parameters 4. Updating architectures to match Devstral exactly 5. Ensuring all parameters match Devstral's config exactly """ logger.info(f"Transforming config at {config_path}") with open(config_path, "r") as f: config = json.load(f) if reference_config: logger.info("Using reference config as template") new_config = reference_config.copy() text_config = config.get("text_config", config) for key, value in text_config.items(): if key not in new_config and key != "model_type": new_config[key] = value logger.info(f"Added parameter from original config: {key}") else: logger.info("No reference config available, using basic transformation") new_config = config.copy() # Change model_type from mistral3 to mistral if new_config.get("model_type") == "mistral3": new_config["model_type"] = "mistral" logger.info("Changed model_type from 'mistral3' to 'mistral'") # Update architectures to use MistralForCausalLM if "architectures" in new_config: new_config["architectures"] = ["MistralForCausalLM"] logger.info("Changed architecture to 'MistralForCausalLM'") # Remove vision_config if "vision_config" in new_config: del new_config["vision_config"] logger.info("Removed vision_config") # Remove multimodal-related parameters multimodal_params = [ "image_token_index", "multimodal_projector_bias", "projector_hidden_act", "spatial_merge_size", "vision_tower_layer_list", "vision_feature_layer" ] for param in multimodal_params: if param in new_config: del new_config[param] logger.info(f"Removed multimodal parameter: {param}") if "text_config" in new_config: text_config = new_config.pop("text_config") for key, value in text_config.items(): if key != "model_type": # Don't overwrite the model_type new_config[key] = value logger.info("Moved text_config parameters to top level") if "bos_token_id" not in new_config: new_config["bos_token_id"] = 1 logger.info("Added bos_token_id: 1") if "eos_token_id" not in new_config: new_config["eos_token_id"] = 2 logger.info("Added eos_token_id: 2") if "tie_word_embeddings" not in new_config: new_config["tie_word_embeddings"] = False logger.info("Added tie_word_embeddings: false") new_config["transformers_version"] = "4.51.3" logger.info("Updated transformers_version to 4.51.3") os_output_path = Path(output_path) / "config.json" with open(os_output_path, "w") as f: json.dump(new_config, f, indent=2) logger.info(f"Saved transformed config to {os_output_path}") return new_config def is_vision_weight(weight_name): """Check if a weight is related to vision functionality""" vision_patterns = ["vision_tower", "multi_modal_projector"] return any(pattern in weight_name for pattern in vision_patterns) def transform_weights(model_path, output_path, safetensors_index_path, reference_weight_map=None): """ Transform model weights by: 1. Loading the weight map from safetensors index 2. Filtering out vision-related weights 3. Removing the "language_model." prefix from weight names 4. Ensuring the exact same partitioning as Devstral 5. Saving the filtered weights to the output path """ logger.info(f"Transforming weights using index at {safetensors_index_path}") with open(safetensors_index_path, "r") as f: index_data = json.load(f) original_weight_map = index_data.get("weight_map", {}) # Count vision and non-vision weights vision_weights = [name for name in original_weight_map if is_vision_weight(name)] non_vision_weights = [name for name in original_weight_map if not is_vision_weight(name)] logger.info(f"Found {len(vision_weights)} vision-related weights to remove") logger.info(f"Found {len(non_vision_weights)} non-vision weights to keep") # Create a mapping from original weight names to Devstral-style weight names weight_name_mapping = {} for original_name in non_vision_weights: if original_name.startswith("language_model."): new_name = original_name[len("language_model."):] weight_name_mapping[original_name] = new_name else: weight_name_mapping[original_name] = original_name logger.info(f"Created mapping for {len(weight_name_mapping)} weight names") new_weight_map = {} if reference_weight_map and "weight_map" in reference_weight_map: devstral_weight_map = reference_weight_map["weight_map"] logger.info(f"Using Devstral reference weight map with {len(devstral_weight_map)} entries") for original_name, new_name in weight_name_mapping.items(): if new_name in devstral_weight_map: new_weight_map[new_name] = devstral_weight_map[new_name] else: logger.warning(f"Weight {new_name} not found in Devstral reference map") else: logger.warning("No Devstral reference map available, using original partitioning") for original_name, new_name in weight_name_mapping.items(): new_weight_map[new_name] = original_weight_map[original_name] # Group weights by their safetensor file for the actual transformation file_to_weights = {} for new_name, file_name in new_weight_map.items(): if file_name not in file_to_weights: file_to_weights[file_name] = [] original_names = [orig for orig, new in weight_name_mapping.items() if new == new_name] if original_names: file_to_weights[file_name].append((original_names[0], new_name)) os.makedirs(Path(output_path), exist_ok=True) # Process each safetensor file for file_name, weight_pairs in file_to_weights.items(): logger.info(f"Processing {file_name} with {len(weight_pairs)} weights") tensors_to_save = {} for original_name, new_name in weight_pairs: original_file = original_weight_map.get(original_name) if not original_file: logger.warning(f"Original file not found for weight {original_name}") continue input_file_path = Path(model_path) / original_file if not input_file_path.exists(): logger.warning(f"File {input_file_path} does not exist, skipping") continue try: original_tensors = load_file(input_file_path) if original_name in original_tensors: tensors_to_save[new_name] = original_tensors[original_name] else: logger.warning(f"Weight {original_name} not found in {original_file}") except Exception as e: logger.error(f"Error loading {original_file}: {e}") if tensors_to_save: output_file_path = Path(output_path) / file_name try: save_file(tensors_to_save, output_file_path) logger.info(f"Saved {len(tensors_to_save)} weights to {file_name}") except Exception as e: logger.error(f"Error saving {file_name}: {e}") # Save the new safetensors index new_index = { "metadata": {"total_size": reference_weight_map.get("metadata", {}).get("total_size", 0)} if reference_weight_map else index_data.get("metadata", {}), "weight_map": new_weight_map } output_index_path = Path(output_path) / "model.safetensors.index.json" with open(output_index_path, "w") as f: json.dump(new_index, f, indent=2) logger.info(f"Saved transformed safetensors index to {output_index_path}") def copy_additional_files(model_path, output_path): """Copy additional model files like tokenizer, generation config, etc.""" additional_files = [ "tokenizer.json", "tokenizer_config.json", "special_tokens_map.json", "generation_config.json" ] for filename in additional_files: src_path = Path(model_path) / filename if src_path.exists(): dst_path = Path(output_path) / filename shutil.copy(src_path, dst_path) logger.info(f"Copied {filename} to output directory") else: logger.warning(f"File {filename} not found in model directory") def download_minimal_files(repo_id, output_dir, cache_dir=None): """Download only the necessary files for transformation without the full model""" logger.info(f"Downloading minimal files from {repo_id}") # List of files to download files_to_download = [ "config.json", "model.safetensors.index.json", "tokenizer_config.json", "special_tokens_map.json", "generation_config.json" ] downloaded_files = {} for filename in files_to_download: try: file_path = hf_hub_download( repo_id=repo_id, filename=filename, cache_dir=cache_dir, local_files_only=False ) downloaded_files[filename] = file_path logger.info(f"Downloaded {filename} to {file_path}") except Exception as e: logger.warning(f"Failed to download {filename}: {e}") return downloaded_files def download_reference_weight_map(reference_model, cache_dir=None): """Download reference model's weight map to use as a reference""" logger.info(f"Downloading reference weight map from {reference_model}") try: file_path = hf_hub_download( repo_id=reference_model, filename="model.safetensors.index.json", cache_dir=cache_dir, local_files_only=False ) with open(file_path, "r") as f: reference_map = json.load(f) logger.info(f"Successfully loaded reference weight map with {len(reference_map.get('weight_map', {}))} weights") return reference_map except Exception as e: logger.error(f"Failed to download reference weight map: {e}") return None def download_reference_config(reference_model, cache_dir=None): """Download reference model's config.json to use as a reference""" logger.info(f"Downloading reference config from {reference_model}") try: file_path = hf_hub_download( repo_id=reference_model, filename="config.json", cache_dir=cache_dir, local_files_only=False ) with open(file_path, "r") as f: reference_config = json.load(f) logger.info(f"Successfully loaded reference config") return reference_config except Exception as e: logger.error(f"Failed to download reference config: {e}") return None def verify_model(output_path): """Verify that the transformed model can be loaded without errors""" logger.info(f"Verifying transformed model at {output_path}") try: config = AutoConfig.from_pretrained(output_path) logger.info(f"Successfully loaded config with model_type={config.model_type}") # Attempt to load just the model architecture (without weights) # This verifies the configuration is valid AutoModelForCausalLM.from_config(config) logger.info("Successfully loaded model architecture from config") return True except Exception as e: logger.error(f"Error verifying model: {e}") return False def main(): args = parse_args() input_model = args.input_model output_path = args.output_path cache_dir = args.cache_dir reference_model = args.reference_model # Download reference weight map and config reference_weight_map = download_reference_weight_map(reference_model, cache_dir) if not reference_weight_map: logger.warning("Could not download reference weight map. The weight partitioning may not match exactly.") reference_config = download_reference_config(reference_model, cache_dir) if not reference_config: logger.warning("Could not download reference config. The config may not match exactly.") # Create output directory os.makedirs(output_path, exist_ok=True) # Download the full model if not os.path.exists(input_model) or not os.path.isdir(input_model): logger.info(f"Downloading model from {input_model}") try: model_path = snapshot_download( repo_id=input_model, cache_dir=cache_dir, local_files_only=False, ignore_patterns=["*consolidated*"] ) except Exception as e: logger.error(f"Error downloading model: {e}") return else: model_path = input_model logger.info(f"Model path: {model_path}") # Transform config config_path = os.path.join(model_path, "config.json") transform_config(config_path, output_path, reference_config) # Transform weights safetensors_index_path = os.path.join(model_path, "model.safetensors.index.json") transform_weights( model_path, output_path, safetensors_index_path, reference_weight_map=reference_weight_map ) # Copy additional files copy_additional_files(model_path, output_path) # Verify the transformed model success = verify_model(output_path) if success: logger.info(f"Successfully transformed model to {output_path}") else: logger.error(f"Failed to transform model properly") if __name__ == "__main__": main()