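# Convert a Hugging Face Mimi audio-codec checkpoint (e.g. kyutai/mimi) to GGUF.
#
# Example invocation (the script filename is illustrative; flags match the argument parsing below):
#   python convert_mimi_to_gguf.py kyutai/mimi --outfile kyutai-mimi.gguf --outtype f16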
import gguf
import argparse
import logging
import torch
from typing import Union
from pathlib import Path
from torch import Tensor
from transformers import MimiModel, PreTrainedModel
logger = logging.getLogger("mimi")
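# Loads a MimiModel via transformers, writes its config as GGUF metadata keys,
# then streams the (optionally quantized) tensors into a single GGUF file.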
class MimiModelConverter:
mimi_model: PreTrainedModel
gguf_writer: gguf.GGUFWriter
fname_out: Path
ftype: gguf.LlamaFileType
def __init__(self,
pretrained_model_name_or_path: Union[Path, str],
fname_out: Path,
ftype: gguf.LlamaFileType,
is_big_endian: bool,):
# --- Load Model ---
self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path)
self.config = self.mimi_model.config # Store config for easier access
logger.info(f"Loaded model config: {self.config}")
self.fname_out = fname_out
self.ftype = ftype
endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
# --- Initialize GGUF Writer ---
self.gguf_writer = gguf.GGUFWriter(
path=None, # Path set during write
arch="mimi", # Set arch to 'mimi' instead of warning message
endianess=endianess)
# --- Add Metadata ---
logger.info("Adding metadata keys...")
# General Mimi parameters (adjust key names if C++ code expects differently)
self.gguf_writer.add_uint32("mimi.sample_rate", self.config.sampling_rate)
self.gguf_writer.add_uint32("mimi.hidden_size", self.config.hidden_size) # Assuming a general hidden size if available
self.gguf_writer.add_uint32("mimi.num_hidden_layers", self.config.num_hidden_layers) # The one confirmed missing
self.gguf_writer.add_uint32("mimi.intermediate_size", self.config.intermediate_size)
# Encoder specific (assuming these exist in config)
if hasattr(self.config, 'encoder_hidden_size'):
self.gguf_writer.add_uint32("mimi.encoder.hidden_size", self.config.encoder_hidden_size)
# Add other encoder params if needed, e.g., embedding dim, num layers if different
# Decoder specific (assuming these exist in config)
if hasattr(self.config, 'decoder_hidden_size'):
self.gguf_writer.add_uint32("mimi.decoder.hidden_size", self.config.decoder_hidden_size)
# Add other decoder params if needed
# RVQ specific (check exact names in config.json or C++ code)
# Using common names found in similar models, adjust if needed.
if hasattr(self.config, 'num_codebooks'):
self.gguf_writer.add_uint32("mimi.rvq.num_quantizers", self.config.num_codebooks)
if hasattr(self.config, 'codebook_dim'):
self.gguf_writer.add_uint32("mimi.rvq.codebook_dim", self.config.codebook_dim)
if hasattr(self.config, 'codebook_size'):
self.gguf_writer.add_uint32("mimi.rvq.codebook_size", self.config.codebook_size) # Might be needed by C++
logger.info("Finished adding metadata keys.")
assert self.config.architectures[0] == "MimiModel"
# --- Load and Add Tensors ---
logger.info("Processing and adding tensors...")
for name, data_torch in self.mimi_model.state_dict().items():
# convert any unsupported data types to float32
old_dtype = data_torch.dtype
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
self.add_tensor(name, data_torch, old_dtype)
logger.info("Finished processing tensors.")
def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype):
is_1d = len(data_torch.shape) == 1
is_bias = ".bias" in name
can_quantize = not is_1d and not is_bias
data_qtype = gguf.GGMLQuantizationType.F32
n_head = self.mimi_model.config.num_attention_heads
n_kv_head = self.mimi_model.config.num_key_value_heads
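# Undo the head-wise RoPE permutation on Q/K projection weights and biases so
# the exported layout matches what the ggml-side implementation expects.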
if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = self.undo_permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
data_torch = self.undo_permute(data_torch, n_head, n_kv_head)
# process codebook
if ".codebook.initialized" in name:
# "initialized" tensor
state_dict = self.mimi_model.state_dict()
embed_sum = state_dict[name.replace(".initialized", ".embed_sum")]
cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")]
# see modeling_mimi.py --> MimiEuclideanCodebook
data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None]
name = name.replace(".initialized", "")
# ignore processed tensors
if ".cluster_usage" in name or ".embed_sum" in name:
return
# transpose some tensors
if ".conv.bias" in name:
data_torch = data_torch.view((1, data_torch.shape[0]))
data_torch = data_torch.transpose(0, 1)
# change view 3d to 2d
if "quantizer" in name and "_proj." in name:
assert data_torch.shape[2] == 1
data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1]))
# shorten name, otherwise it will be too long for ggml to read
name = name.replace("_residual_vector_quantizer", "_rvq")
if can_quantize:
if self.ftype == gguf.LlamaFileType.ALL_F32:
data_qtype = gguf.GGMLQuantizationType.F32
elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
data_qtype = gguf.GGMLQuantizationType.F16
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
data_qtype = gguf.GGMLQuantizationType.BF16
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
data_qtype = gguf.GGMLQuantizationType.Q8_0
else:
raise ValueError(f"Unsupported file type: {self.ftype}")
# Conv kernels are always F16
if ".conv.weight" in name:
data_qtype = gguf.GGMLQuantizationType.F16
data = data_torch.numpy()
try:
data = gguf.quants.quantize(data, data_qtype)
except Exception as e:
logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
data_qtype = gguf.GGMLQuantizationType.F16
data = gguf.quants.quantize(data, data_qtype)
# reverse shape to make it similar to the internal ggml dimension order
shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
logger.debug(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
def write(self):
self.gguf_writer.write_header_to_file(path=self.fname_out)
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()
logger.info(f"Model successfully converted and saved to {self.fname_out}") # Added confirmation message
@staticmethod
def undo_permute(weights: Tensor, n_head: int, n_head_kv: int):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape))
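# CLI entry point: parse arguments, map the requested output type to a
# gguf.LlamaFileType, run the conversion under torch.inference_mode(), and
# write the resulting GGUF file.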
def main() -> None:
parser = argparse.ArgumentParser(
description="Convert a Mimi safetensors model to GGUF with metadata")
parser.add_argument(
"--outfile", type=Path, default="kyutai-mimi.gguf",
help="path to write to",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
help="output format",
)
parser.add_argument(
"--bigendian", action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"model", type=str,
help="directory or model ID containing model file (if model ID is specified, download from Hugging Face hub)",
nargs="?",
default="kyutai/mimi",
)
parser.add_argument(
"--verbose", action="store_true",
help="increase output verbosity",
)
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
dir_model = args.model
fname_out = args.outfile
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
}
logger.info(f"Loading model: {dir_model}")
with torch.inference_mode():
converter = MimiModelConverter(
pretrained_model_name_or_path=dir_model,
fname_out=fname_out,
ftype=ftype_map[args.outtype],
is_big_endian=args.bigendian,
)
converter.write()
if __name__ == '__main__':
main()