|
|
|
""" |
|
Mistral Model Transformer |
|
|
|
This script transforms Mistral-Small-3.1-24B-Base-2503 into a text-only model by: |
|
1. Removing multimodality features |
|
2. Removing the vision encoder |
|
3. Changing the architecture from "mistral3" to "mistral" |
|
4. Ensuring weight mapping structure matches Devstral-Small-2505 exactly |
|
|
|
Usage: |
|
python convert.py --input-model mistralai/Mistral-Small-3.1-24B-Base-2503 --output-path ./mistral-small-text-only --reference-model mistralai/Devstral-Small-2505 |
|
|
|
Note: |
|
This script requires significant disk space to download and process the full model. |
|
""" |
|
|
|
import argparse |
|
import json |
|
import os |
|
import shutil |
|
from pathlib import Path |
|
import logging |
|
|
|
from huggingface_hub import snapshot_download, hf_hub_download |
|
from safetensors.torch import load_file, save_file |
|
from transformers import AutoConfig, AutoModelForCausalLM |
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
logger = logging.getLogger(__name__) |
|
def parse_args():
    """Parse and return the command-line arguments for the conversion script."""
    parser = argparse.ArgumentParser(description="Transform Mistral model to text-only version")
    parser.add_argument(
        "--input-model",
        type=str,
        default="mistralai/Mistral-Small-3.1-24B-Base-2503",
        help="Path or HF repo id of the input model",
    )
    parser.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Path to save the transformed model",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Cache directory for downloading models",
    )
    parser.add_argument(
        "--reference-model",
        type=str,
        default="mistralai/Devstral-Small-2505",
        help="Path or HF repo id of the reference model for weight mapping",
    )
    return parser.parse_args()
|
|
|
def transform_config(config_path, output_path, reference_config=None):
    """
    Build a text-only "mistral" config from a multimodal "mistral3" config.

    Steps:
    1. Change model_type from "mistral3" to "mistral"
    2. Remove vision_config
    3. Remove multimodal-only parameters
    4. Force architectures to ["MistralForCausalLM"]
    5. When a reference (Devstral) config is supplied, use it as the template
       and only backfill parameters it lacks from the original text config

    Writes the result to <output_path>/config.json and returns it as a dict.
    """
    logger.info(f"Transforming config at {config_path}")

    with open(config_path, "r") as f:
        source_cfg = json.load(f)

    if reference_config:
        logger.info("Using reference config as template")
        cfg = reference_config.copy()

        # Original text parameters may live under "text_config"; fall back to
        # the top level when that nesting is absent.
        original_text_cfg = source_cfg.get("text_config", source_cfg)

        # Reference values win; only add keys the reference is missing.
        for key, value in original_text_cfg.items():
            if key != "model_type" and key not in cfg:
                cfg[key] = value
                logger.info(f"Added parameter from original config: {key}")
    else:
        logger.info("No reference config available, using basic transformation")
        cfg = source_cfg.copy()

    if cfg.get("model_type") == "mistral3":
        cfg["model_type"] = "mistral"
        logger.info("Changed model_type from 'mistral3' to 'mistral'")

    if "architectures" in cfg:
        cfg["architectures"] = ["MistralForCausalLM"]
        logger.info("Changed architecture to 'MistralForCausalLM'")

    if "vision_config" in cfg:
        del cfg["vision_config"]
        logger.info("Removed vision_config")

    # Parameters that only make sense for the multimodal variant.
    for param in (
        "image_token_index",
        "multimodal_projector_bias",
        "projector_hidden_act",
        "spatial_merge_size",
        "vision_tower_layer_list",
        "vision_feature_layer",
    ):
        if param in cfg:
            del cfg[param]
            logger.info(f"Removed multimodal parameter: {param}")

    # Flatten any remaining nested text_config into the top level.
    if "text_config" in cfg:
        nested = cfg.pop("text_config")
        cfg.update({k: v for k, v in nested.items() if k != "model_type"})
        logger.info("Moved text_config parameters to top level")

    # Fill in defaults the text-only model expects if they are absent.
    for key, value, note in (
        ("bos_token_id", 1, "Added bos_token_id: 1"),
        ("eos_token_id", 2, "Added eos_token_id: 2"),
        ("tie_word_embeddings", False, "Added tie_word_embeddings: false"),
    ):
        if key not in cfg:
            cfg[key] = value
            logger.info(note)

    cfg["transformers_version"] = "4.51.3"
    logger.info("Updated transformers_version to 4.51.3")

    target_path = Path(output_path) / "config.json"
    with open(target_path, "w") as f:
        json.dump(cfg, f, indent=2)

    logger.info(f"Saved transformed config to {target_path}")
    return cfg
|
|
|
def is_vision_weight(weight_name):
    """Return True when the weight name belongs to the vision stack."""
    for pattern in ("vision_tower", "multi_modal_projector"):
        if pattern in weight_name:
            return True
    return False
|
|
|
def transform_weights(model_path, output_path, safetensors_index_path, reference_weight_map=None):
    """
    Transform model weights by:
    1. Loading the weight map from the safetensors index
    2. Filtering out vision-related weights
    3. Removing the "language_model." prefix from weight names
    4. Ensuring the exact same partitioning as Devstral (when a reference
       weight map is supplied)
    5. Saving the filtered weights to the output path

    Performance fix vs. the previous version: each source safetensors shard is
    loaded from disk exactly once per output file (it used to be fully
    re-loaded for every single tensor), and the new-name -> original-name
    lookup is a precomputed dict instead of an O(n) scan per weight.
    """
    logger.info(f"Transforming weights using index at {safetensors_index_path}")

    with open(safetensors_index_path, "r") as f:
        index_data = json.load(f)

    original_weight_map = index_data.get("weight_map", {})

    vision_weights = [name for name in original_weight_map if is_vision_weight(name)]
    non_vision_weights = [name for name in original_weight_map if not is_vision_weight(name)]

    logger.info(f"Found {len(vision_weights)} vision-related weights to remove")
    logger.info(f"Found {len(non_vision_weights)} non-vision weights to keep")

    # Map original names to text-only names by stripping the "language_model."
    # prefix; names without the prefix pass through unchanged.
    weight_name_mapping = {}
    for original_name in non_vision_weights:
        if original_name.startswith("language_model."):
            weight_name_mapping[original_name] = original_name[len("language_model."):]
        else:
            weight_name_mapping[original_name] = original_name

    logger.info(f"Created mapping for {len(weight_name_mapping)} weight names")

    # Precompute the reverse lookup (new name -> original name). Prefix
    # stripping is injective for these checkpoints, so this is unambiguous.
    reverse_name_mapping = {new: orig for orig, new in weight_name_mapping.items()}

    new_weight_map = {}

    if reference_weight_map and "weight_map" in reference_weight_map:
        devstral_weight_map = reference_weight_map["weight_map"]
        logger.info(f"Using Devstral reference weight map with {len(devstral_weight_map)} entries")

        # Adopt the reference partitioning: each renamed weight goes to the
        # shard Devstral stores it in.
        for original_name, new_name in weight_name_mapping.items():
            if new_name in devstral_weight_map:
                new_weight_map[new_name] = devstral_weight_map[new_name]
            else:
                logger.warning(f"Weight {new_name} not found in Devstral reference map")
    else:
        logger.warning("No Devstral reference map available, using original partitioning")
        for original_name, new_name in weight_name_mapping.items():
            new_weight_map[new_name] = original_weight_map[original_name]

    # Group the renamed weights by destination shard.
    file_to_weights = {}
    for new_name, file_name in new_weight_map.items():
        file_to_weights.setdefault(file_name, []).append(
            (reverse_name_mapping[new_name], new_name)
        )

    os.makedirs(Path(output_path), exist_ok=True)

    for file_name, weight_pairs in file_to_weights.items():
        logger.info(f"Processing {file_name} with {len(weight_pairs)} weights")

        # Group by *source* shard so each shard is read from disk only once.
        source_to_pairs = {}
        for original_name, new_name in weight_pairs:
            original_file = original_weight_map.get(original_name)
            if not original_file:
                logger.warning(f"Original file not found for weight {original_name}")
                continue
            source_to_pairs.setdefault(original_file, []).append((original_name, new_name))

        tensors_to_save = {}

        for original_file, pairs in source_to_pairs.items():
            input_file_path = Path(model_path) / original_file
            if not input_file_path.exists():
                logger.warning(f"File {input_file_path} does not exist, skipping")
                continue

            try:
                original_tensors = load_file(input_file_path)
            except Exception as e:
                logger.error(f"Error loading {original_file}: {e}")
                continue

            for original_name, new_name in pairs:
                if original_name in original_tensors:
                    tensors_to_save[new_name] = original_tensors[original_name]
                else:
                    logger.warning(f"Weight {original_name} not found in {original_file}")

        if tensors_to_save:
            output_file_path = Path(output_path) / file_name
            try:
                save_file(tensors_to_save, output_file_path)
                logger.info(f"Saved {len(tensors_to_save)} weights to {file_name}")
            except Exception as e:
                logger.error(f"Error saving {file_name}: {e}")

    # Write the new index; take total_size from the reference when available
    # so the index matches Devstral's exactly.
    new_index = {
        "metadata": {"total_size": reference_weight_map.get("metadata", {}).get("total_size", 0)}
        if reference_weight_map else index_data.get("metadata", {}),
        "weight_map": new_weight_map
    }

    output_index_path = Path(output_path) / "model.safetensors.index.json"
    with open(output_index_path, "w") as f:
        json.dump(new_index, f, indent=2)

    logger.info(f"Saved transformed safetensors index to {output_index_path}")
|
|
|
def copy_additional_files(model_path, output_path):
    """Copy auxiliary model files (tokenizer, generation config, ...) to the output dir.

    Missing files are logged as warnings rather than treated as errors.
    """
    additional_files = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json"
    ]

    for filename in additional_files:
        src_path = Path(model_path) / filename
        if src_path.exists():
            dst_path = Path(output_path) / filename
            shutil.copy(src_path, dst_path)
            # Bug fix: the log messages previously printed the literal
            # "(unknown)" instead of the actual filename.
            logger.info(f"Copied {filename} to output directory")
        else:
            logger.warning(f"File {filename} not found in model directory")
|
|
|
def download_minimal_files(repo_id, output_dir, cache_dir=None):
    """Download only the files needed for transformation, not the full model.

    Returns a dict mapping each successfully downloaded filename to its local
    path; failures are logged as warnings and omitted from the result.
    """
    logger.info(f"Downloading minimal files from {repo_id}")

    files_to_download = [
        "config.json",
        "model.safetensors.index.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "generation_config.json"
    ]

    downloaded_files = {}

    for filename in files_to_download:
        try:
            file_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=cache_dir,
                local_files_only=False
            )
            downloaded_files[filename] = file_path
            # Bug fix: the log messages previously printed the literal
            # "(unknown)" instead of the actual filename.
            logger.info(f"Downloaded {filename} to {file_path}")
        except Exception as e:
            logger.warning(f"Failed to download {filename}: {e}")

    return downloaded_files
|
|
|
def download_reference_weight_map(reference_model, cache_dir=None):
    """Fetch the reference model's safetensors index; return it as a dict, or None on failure."""
    logger.info(f"Downloading reference weight map from {reference_model}")

    try:
        index_path = hf_hub_download(
            repo_id=reference_model,
            filename="model.safetensors.index.json",
            cache_dir=cache_dir,
            local_files_only=False
        )

        with open(index_path, "r") as fh:
            reference_map = json.load(fh)

        logger.info(f"Successfully loaded reference weight map with {len(reference_map.get('weight_map', {}))} weights")
        return reference_map
    except Exception as e:
        logger.error(f"Failed to download reference weight map: {e}")
        return None
|
|
|
def download_reference_config(reference_model, cache_dir=None):
    """Fetch the reference model's config.json; return it as a dict, or None on failure."""
    logger.info(f"Downloading reference config from {reference_model}")

    try:
        config_file = hf_hub_download(
            repo_id=reference_model,
            filename="config.json",
            cache_dir=cache_dir,
            local_files_only=False
        )

        with open(config_file, "r") as fh:
            reference_config = json.load(fh)

        logger.info(f"Successfully loaded reference config")
        return reference_config
    except Exception as e:
        logger.error(f"Failed to download reference config: {e}")
        return None
|
|
|
def verify_model(output_path):
    """Smoke-test the transformed model: config must load and the architecture
    must instantiate from it. Returns True on success, False otherwise."""
    logger.info(f"Verifying transformed model at {output_path}")
    try:
        cfg = AutoConfig.from_pretrained(output_path)
        logger.info(f"Successfully loaded config with model_type={cfg.model_type}")

        # Instantiating from config (no weights) is enough to validate the
        # architecture without loading tens of GB of tensors.
        AutoModelForCausalLM.from_config(cfg)
        logger.info("Successfully loaded model architecture from config")
    except Exception as e:
        logger.error(f"Error verifying model: {e}")
        return False
    return True
|
|
|
def main():
    """Run the end-to-end transformation: fetch references, download the input
    model, transform config and weights, copy auxiliary files, and verify."""
    args = parse_args()

    input_model = args.input_model
    output_path = args.output_path
    cache_dir = args.cache_dir
    reference_model = args.reference_model

    # Reference artifacts are best-effort: the transformation still proceeds
    # without them, just with weaker matching guarantees.
    reference_weight_map = download_reference_weight_map(reference_model, cache_dir)
    if not reference_weight_map:
        logger.warning("Could not download reference weight map. The weight partitioning may not match exactly.")

    reference_config = download_reference_config(reference_model, cache_dir)
    if not reference_config:
        logger.warning("Could not download reference config. The config may not match exactly.")

    os.makedirs(output_path, exist_ok=True)

    # Fix: os.path.isdir already returns False for non-existent paths, so the
    # previous extra os.path.exists check was redundant.
    if not os.path.isdir(input_model):
        logger.info(f"Downloading model from {input_model}")
        try:
            model_path = snapshot_download(
                repo_id=input_model,
                cache_dir=cache_dir,
                local_files_only=False,
                ignore_patterns=["*consolidated*"]  # skip duplicate consolidated weights
            )
        except Exception as e:
            logger.error(f"Error downloading model: {e}")
            return
    else:
        model_path = input_model

    logger.info(f"Model path: {model_path}")

    config_path = os.path.join(model_path, "config.json")
    transform_config(config_path, output_path, reference_config)

    safetensors_index_path = os.path.join(model_path, "model.safetensors.index.json")
    transform_weights(
        model_path,
        output_path,
        safetensors_index_path,
        reference_weight_map=reference_weight_map
    )

    copy_additional_files(model_path, output_path)

    success = verify_model(output_path)

    if success:
        logger.info(f"Successfully transformed model to {output_path}")
    else:
        # Fix: was an f-string with no placeholders.
        logger.error("Failed to transform model properly")

if __name__ == "__main__":
    main()
|
|