import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration


def copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path):
    """
    Copy Devstral language-model weights into a Mistral-Small model,
    preserving Mistral's vision components. Returns the updated model.
    """
    # Both checkpoints are loaded on CPU in bf16; the transfer below is pure
    # tensor copying, so no GPU is required.
    print(f"Loading Devstral model from {devstral_id}...")
    devstral_model = AutoModelForCausalLM.from_pretrained(
        devstral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    print(f"Loading Mistral-Small model from {mistral_id}...")
    mistral_model = Mistral3ForConditionalGeneration.from_pretrained(
        mistral_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    # If the shipped generation_config sets a temperature while do_sample is
    # False, transformers' validation complains; clearing the temperature
    # keeps greedy decoding and silences the warning.
    print("Fixing generation configuration...")
    if getattr(mistral_model, "generation_config", None) is not None:
        gen_config = mistral_model.generation_config
        if hasattr(gen_config, "do_sample") and hasattr(gen_config, "temperature"):
            if not gen_config.do_sample and gen_config.temperature is not None:
                gen_config.temperature = None
                print(" - Removed temperature setting (keeping do_sample=False)")
        try:
            gen_config.validate()
            print(" - Generation config is now valid")
        except Exception as e:
            print(f" - Warning: Generation config validation failed: {e}")

    devstral_state = devstral_model.state_dict()
    mistral_state = mistral_model.state_dict()

    print("Copying weights from Devstral to Mistral-Small...")

    # Top-level language-model tensors: token embeddings and the final norm.
    # In Mistral3's multimodal layout the text stack lives under
    # model.language_model, so each target key gains that prefix.
    weight_mappings = [
        ("model.embed_tokens.weight", "model.language_model.embed_tokens.weight"),
        ("model.norm.weight", "model.language_model.norm.weight"),
    ]
    for devstral_key, mistral_key in weight_mappings:
        print(f"Copying {devstral_key} to {mistral_key}")
        if devstral_key not in devstral_state or mistral_key not in mistral_state:
            raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
        mistral_state[mistral_key] = devstral_state[devstral_key].clone()
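
    # Hedged extra (an assumption, guarded so it is a no-op when absent): if
    # both checkpoints carry an untied lm_head, copy it as well; with tied
    # embeddings the key does not appear in the state dicts.
    if "lm_head.weight" in devstral_state and "lm_head.weight" in mistral_state:
        print("Copying lm_head.weight")
        mistral_state["lm_head.weight"] = devstral_state["lm_head.weight"].clone()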

    # Per-layer tensors: the two layer norms, the MLP projections, and the
    # attention projections. The layer count is read from the Devstral config
    # rather than hard-coded (40 for these checkpoints).
    num_layers = devstral_model.config.num_hidden_layers
    for i in tqdm(range(num_layers), desc="Copying layer weights"):
        layer_mappings = [
            (f"model.layers.{i}.input_layernorm.weight", f"model.language_model.layers.{i}.input_layernorm.weight"),
            (f"model.layers.{i}.mlp.down_proj.weight", f"model.language_model.layers.{i}.mlp.down_proj.weight"),
            (f"model.layers.{i}.mlp.gate_proj.weight", f"model.language_model.layers.{i}.mlp.gate_proj.weight"),
            (f"model.layers.{i}.mlp.up_proj.weight", f"model.language_model.layers.{i}.mlp.up_proj.weight"),
            (f"model.layers.{i}.post_attention_layernorm.weight", f"model.language_model.layers.{i}.post_attention_layernorm.weight"),
            (f"model.layers.{i}.self_attn.k_proj.weight", f"model.language_model.layers.{i}.self_attn.k_proj.weight"),
            (f"model.layers.{i}.self_attn.o_proj.weight", f"model.language_model.layers.{i}.self_attn.o_proj.weight"),
            (f"model.layers.{i}.self_attn.q_proj.weight", f"model.language_model.layers.{i}.self_attn.q_proj.weight"),
            (f"model.layers.{i}.self_attn.v_proj.weight", f"model.language_model.layers.{i}.self_attn.v_proj.weight"),
        ]
        for devstral_key, mistral_key in layer_mappings:
            if devstral_key not in devstral_state or mistral_key not in mistral_state:
                raise KeyError(f"Missing key: {devstral_key} or {mistral_key}")
            mistral_state[mistral_key] = devstral_state[devstral_key].clone()
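
    # Spot-check (an optional sanity assertion): one freshly copied tensor
    # should be bit-identical between source and target.
    assert torch.equal(
        mistral_state["model.language_model.layers.0.self_attn.q_proj.weight"],
        devstral_state["model.layers.0.self_attn.q_proj.weight"],
    )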

    print("Saving updated Mistral-Small model...")
    mistral_model.load_state_dict(mistral_state)
    mistral_model.save_pretrained(output_path, safe_serialization=True)
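
    # Optionally persist the processor/tokenizer alongside the weights so the
    # output folder loads standalone. This is a sketch: it assumes the Mistral
    # checkpoint ships a transformers processor config.
    from transformers import AutoProcessor  # local import keeps the sketch self-contained
    AutoProcessor.from_pretrained(mistral_id).save_pretrained(output_path)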

    return mistral_model


if __name__ == "__main__":
    devstral_id = "mistralai/Devstral-Small-2507"
    mistral_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    output_path = "./Devstral-Vision-Small-2507"

    model = copy_devstral_weights_to_mistral(devstral_id, mistral_id, output_path)
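
    # Optional smoke test (a sketch, left commented out: reloading the ~24B
    # model in bf16 needs roughly 48 GB of RAM):
    # reloaded = Mistral3ForConditionalGeneration.from_pretrained(
    #     output_path, torch_dtype=torch.bfloat16, device_map="cpu"
    # )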