# Devstral-Vision-Small-2507 / make_devstral_vision.py
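"""
Build Devstral-Vision-Small-2507: copy the Devstral-Small-2507 language-model weights into
Mistral-Small-3.2-24B-Instruct-2506 (Mistral3ForConditionalGeneration), keeping Mistral's
vision components, and save the merged checkpoint to ./Devstral-Vision-Small-2507.
"""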
import torch
from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration, AutoTokenizer
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from tqdm import tqdm


def copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path):
    """
    Copy the Devstral language-model weights into the Mistral-Small model,
    preserving Mistral's vision components.
    """
print(f"Loading Devstral model from {devstral_id}...")
devstral_model = AutoModelForCausalLM.from_pretrained(
devstral_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
print(f"Loading Mistral-Small model from {mistral_id}...")
mistral_model = Mistral3ForConditionalGeneration.from_pretrained(
mistral_id,
torch_dtype=torch.bfloat16,
device_map="cpu"
)
print("Fixing generation configuration...")
if hasattr(mistral_model, 'generation_config') and mistral_model.generation_config is not None:
gen_config = mistral_model.generation_config
# Fix the conflicting settings
if hasattr(gen_config, 'do_sample') and hasattr(gen_config, 'temperature'):
if not gen_config.do_sample and gen_config.temperature is not None:
# Option 1: Remove temperature (recommended for deterministic generation)
gen_config.temperature = None
print(" - Removed temperature setting (keeping do_sample=False)")
# Option 2: Alternative - enable sampling
# gen_config.do_sample = True
# print(" - Enabled sampling (keeping temperature=0.15)")
# Validate the config
try:
gen_config.validate()
print(" - Generation config is now valid")
except Exception as e:
print(f" - Warning: Generation config validation failed: {e}")
    devstral_state = devstral_model.state_dict()
    mistral_state = mistral_model.state_dict()

    print("Copying weights from Devstral to Mistral-Small...")
    # Top-level language-model tensors: token embeddings and the final norm
    weight_mappings = [
        ("model.embed_tokens.weight", "model.language_model.embed_tokens.weight"),
        ("model.norm.weight", "model.language_model.norm.weight"),
    ]
    for devstral_key, mistral_key in weight_mappings:
        print(f"Copying {devstral_key} to {mistral_key}")
        if devstral_key not in devstral_state or mistral_key not in mistral_state:
            # Abort if any key is missing
            raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
        mistral_state[mistral_key] = devstral_state[devstral_key].clone()
    # Copy all other (per-layer) weights from Devstral to Mistral
    for i in tqdm(range(40), desc="Copying layer weights"):
        layer_mappings = [
            # layers.[0-39].attention.wk.weight,[1024,5120],BF16
            # (f"model.layers.{i}.self_attn.wk.weight", f"layers.{i}.self_attn.wk.weight"),
            # (f"model.layers.{i}.self_attn.wo.weight", f"layers.{i}.self_attn.wo.weight"),
            # (f"model.layers.{i}.self_attn.wq.weight", f"layers.{i}.self_attn.wq.weight"),
            # (f"model.layers.{i}.self_attn.wv.weight", f"layers.{i}.self_attn.wv.weight"),
            # (f"model.layers.{i}.attention_norm.weight", f"layers.{i}.attention_norm.weight"),
            # (f"model.layers.{i}.feed_forward.w1.weight", f"layers.{i}.feed_forward.w1.weight"),
            # (f"model.layers.{i}.feed_forward.w2.weight", f"layers.{i}.feed_forward.w2.weight"),
            # (f"model.layers.{i}.feed_forward.w3.weight", f"layers.{i}.feed_forward.w3.weight"),
            # (f"model.layers.{i}.ffn_norm.weight", f"layers.{i}.ffn_norm.weight"),
            (f"model.layers.{i}.input_layernorm.weight", f"model.language_model.layers.{i}.input_layernorm.weight"),
            (f"model.layers.{i}.mlp.down_proj.weight", f"model.language_model.layers.{i}.mlp.down_proj.weight"),
            (f"model.layers.{i}.mlp.gate_proj.weight", f"model.language_model.layers.{i}.mlp.gate_proj.weight"),
            (f"model.layers.{i}.mlp.up_proj.weight", f"model.language_model.layers.{i}.mlp.up_proj.weight"),
            (f"model.layers.{i}.post_attention_layernorm.weight", f"model.language_model.layers.{i}.post_attention_layernorm.weight"),
            (f"model.layers.{i}.self_attn.k_proj.weight", f"model.language_model.layers.{i}.self_attn.k_proj.weight"),
            (f"model.layers.{i}.self_attn.o_proj.weight", f"model.language_model.layers.{i}.self_attn.o_proj.weight"),
            (f"model.layers.{i}.self_attn.q_proj.weight", f"model.language_model.layers.{i}.self_attn.q_proj.weight"),
            (f"model.layers.{i}.self_attn.v_proj.weight", f"model.language_model.layers.{i}.self_attn.v_proj.weight"),
        ]
        for devstral_key, mistral_key in layer_mappings:
            if devstral_key not in devstral_state or mistral_key not in mistral_state:
                raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
            mistral_state[mistral_key] = devstral_state[devstral_key].clone()
print("Saving updated Mistral-Small model...")
mistral_model.load_state_dict(mistral_state)
mistral_model.save_pretrained(output_path, safe_serialization=True)
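

# A minimal optional helper, not part of the original merge: save_pretrained() above writes only
# the model weights and configs, so a directly usable multimodal checkpoint still needs the
# tokenizer/processor files. This sketch assumes the Mistral-Small base repo ships
# transformers-compatible processor files that AutoProcessor can load.
def copy_processor_files(mistral_id, output_path):
    from transformers import AutoProcessor
    processor = AutoProcessor.from_pretrained(mistral_id)
    processor.save_pretrained(output_path)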


if __name__ == "__main__":
    devstral_id = "mistralai/Devstral-Small-2507"
    mistral_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    output_path = "./Devstral-Vision-Small-2507"
    model = copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path)
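
    # Optional sanity checks (a sketch; key prefixes assume the Mistral3 layout used in the
    # mappings above: "model.vision_tower.*" for vision, "model.language_model.*" for text).
    merged_state = model.state_dict()
    assert any(k.startswith("model.vision_tower.") for k in merged_state), "vision tower weights missing"
    assert "model.language_model.embed_tokens.weight" in merged_state, "language model weights missing"
    print(f"Merged model saved to {output_path} ({sum(p.numel() for p in model.parameters()):,} parameters)")

    # Optionally copy the processor/tokenizer files from the base repo (see helper sketch above):
    # copy_processor_files(mistral_id, output_path)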