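"""Convert the Kyutai Mimi audio codec model from Hugging Face format to GGUF.

Loads a MimiModel checkpoint with transformers, writes the configuration values
the C++ loader needs as GGUF metadata, then converts and (optionally) quantizes
each tensor before writing everything to a single .gguf file.
"""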
import gguf
import argparse
import logging
import torch
from typing import Union
from pathlib import Path
from torch import Tensor
from transformers import MimiModel, PreTrainedModel

logger = logging.getLogger("mimi")


class MimiModelConverter:
    mimi_model: PreTrainedModel
    gguf_writer: gguf.GGUFWriter
    fname_out: Path
    ftype: gguf.LlamaFileType

    def __init__(self,
                 pretrained_model_name_or_path: Union[Path, str],
                 fname_out: Path,
                 ftype: gguf.LlamaFileType,
                 is_big_endian: bool,):
        # --- Load Model ---
        self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path)
        self.config = self.mimi_model.config # Store config for easier access
        logger.info(f"Loaded model config: {self.config}")

        self.fname_out = fname_out
        self.ftype = ftype
        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        # --- Initialize GGUF Writer ---
        self.gguf_writer = gguf.GGUFWriter(
            path=None,  # the output path is supplied later, in write()
            arch="mimi",
            endianess=endianess)

        # --- Add Metadata ---
        logger.info("Adding metadata keys...")
        # General Mimi parameters (key names must match what the C++ loader expects)
        self.gguf_writer.add_uint32("mimi.sample_rate", self.config.sampling_rate)
        self.gguf_writer.add_uint32("mimi.hidden_size", self.config.hidden_size)
        self.gguf_writer.add_uint32("mimi.num_hidden_layers", self.config.num_hidden_layers)
        self.gguf_writer.add_uint32("mimi.intermediate_size", self.config.intermediate_size)

        # Encoder specific (assuming these exist in config)
        if hasattr(self.config, 'encoder_hidden_size'):
            self.gguf_writer.add_uint32("mimi.encoder.hidden_size", self.config.encoder_hidden_size)
        # Add other encoder params if needed, e.g., embedding dim, num layers if different

        # Decoder specific (assuming these exist in config)
        if hasattr(self.config, 'decoder_hidden_size'):
            self.gguf_writer.add_uint32("mimi.decoder.hidden_size", self.config.decoder_hidden_size)
        # Add other decoder params if needed

        # RVQ specific (check exact names in config.json or C++ code)
        # Using common names found in similar models, adjust if needed.
        if hasattr(self.config, 'num_codebooks'):
            self.gguf_writer.add_uint32("mimi.rvq.num_quantizers", self.config.num_codebooks)
        if hasattr(self.config, 'codebook_dim'):
            self.gguf_writer.add_uint32("mimi.rvq.codebook_dim", self.config.codebook_dim)
        if hasattr(self.config, 'codebook_size'):
            self.gguf_writer.add_uint32("mimi.rvq.codebook_size", self.config.codebook_size)

        logger.info("Finished adding metadata keys.")

        assert self.config.architectures[0] == "MimiModel"

        # --- Load and Add Tensors ---
        logger.info("Processing and adding tensors...")
        for name, data_torch in self.mimi_model.state_dict().items():
            # convert any unsupported data types to float32
            old_dtype = data_torch.dtype
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)
            self.add_tensor(name, data_torch, old_dtype)
        logger.info("Finished processing tensors.")

    def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype):
        is_1d = len(data_torch.shape) == 1
        is_bias = ".bias" in name
        can_quantize = not is_1d and not is_bias
        data_qtype = gguf.GGMLQuantizationType.F32

        n_head = self.mimi_model.config.num_attention_heads
        n_kv_head = self.mimi_model.config.num_key_value_heads
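        # Q and K projection weights are re-ordered with undo_permute (defined
        # below): the interleaved rotary halves of each head are regrouped,
        # presumably so the exported layout matches what the ggml-side
        # attention implementation expects (mirroring llama.cpp's converters).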
        if name.endswith(("q_proj.weight", "q_proj.bias")):
            data_torch = self.undo_permute(data_torch, n_head, n_head)
        if name.endswith(("k_proj.weight", "k_proj.bias")):
            data_torch = self.undo_permute(data_torch, n_head, n_kv_head)

        # process codebook
        if ".codebook.initialized" in name:
            # "initialized" tensor
            state_dict = self.mimi_model.state_dict()
            embed_sum = state_dict[name.replace(".initialized", ".embed_sum")]
            cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")]
            # see modeling_mimi.py --> MimiEuclideanCodebook
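            # The checkpoint stores running statistics rather than the embedding
            # matrix itself; the embeddings are recovered as the per-entry mean
            # embed_sum / cluster_usage, with cluster_usage clamped to avoid
            # division by zero.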
            data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None]
            name = name.replace(".initialized", "")

        # ignore processed tensors
        if ".cluster_usage" in name or ".embed_sum" in name:
            return

        # conv biases are stored as 1-D tensors; reshape them to column vectors (n, 1)
        if ".conv.bias" in name:
            data_torch = data_torch.view((1, data_torch.shape[0]))
            data_torch = data_torch.transpose(0, 1)

        # the RVQ input/output projections are kernel-size-1 convolutions;
        # drop the trailing kernel dimension so they become plain 2-D matrices
        if "quantizer" in name and "_proj." in name:
            assert data_torch.shape[2] == 1
            data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1]))

        # shorten name, otherwise it will be too long for ggml to read
        name = name.replace("_residual_vector_quantizer", "_rvq")

        if can_quantize:
            if self.ftype == gguf.LlamaFileType.ALL_F32:
                data_qtype = gguf.GGMLQuantizationType.F32
            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
                data_qtype = gguf.GGMLQuantizationType.F16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                data_qtype = gguf.GGMLQuantizationType.BF16
            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
                data_qtype = gguf.GGMLQuantizationType.Q8_0
            else:
                raise ValueError(f"Unsupported file type: {self.ftype}")

        # Conv kernels are always F16
        if ".conv.weight" in name:
            data_qtype = gguf.GGMLQuantizationType.F16

        data = data_torch.numpy()

        try:
            data = gguf.quants.quantize(data, data_qtype)
        except Exception as e:
            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
            data_qtype = gguf.GGMLQuantizationType.F16
            data = gguf.quants.quantize(data, data_qtype)

        # reverse the shape to match the internal ggml dimension order
        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
        logger.debug(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")

        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)

    def write(self):
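        # GGUFWriter expects this call order: header first, then the key/value
        # metadata section, then the tensor data.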
        self.gguf_writer.write_header_to_file(path=self.fname_out)
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file(progress=True)
        self.gguf_writer.close()
        logger.info(f"Model successfully converted and saved to {self.fname_out}") # Added confirmation message

    @staticmethod
    def undo_permute(weights: Tensor, n_head: int, n_head_kv: int):
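        # Inverse of the head permutation used in llama.cpp-style converters:
        # split the leading dimension into (n_head, 2, rows_per_head // 2),
        # swap the two middle axes, and flatten back to the original shape.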
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
                .swapaxes(1, 2)
                .reshape(weights.shape))

def main() -> None:
    parser = argparse.ArgumentParser(
        description="Convert a Mimi (safetensors) model to GGUF, including metadata")
    parser.add_argument(
        "--outfile", type=Path, default="kyutai-mimi.gguf",
        help="path to write to",
    )
    parser.add_argument(
        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
        help="output format",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model", type=str,
        help="directory or model ID containing model file (if model ID is specified, download from Hugging Face hub)",
        nargs="?",
        default="kyutai/mimi",
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

    dir_model = args.model
    fname_out = args.outfile

    ftype_map: dict[str, gguf.LlamaFileType] = {
        "f32": gguf.LlamaFileType.ALL_F32,
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
    }

    logger.info(f"Loading model: {dir_model}")

    with torch.inference_mode():
        converter = MimiModelConverter(
            pretrained_model_name_or_path=dir_model,
            fname_out=fname_out,
            ftype=ftype_map[args.outtype],
            is_big_endian=args.bigendian,
        )
        converter.write()


if __name__ == '__main__':
    main()
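
# Example usage (illustrative; assumes this script is saved as convert_mimi_to_gguf.py):
#
#   python convert_mimi_to_gguf.py kyutai/mimi --outtype q8_0 --outfile mimi-q8_0.gguf
#
# A quick sanity check of the written file with the gguf-py reader might look like:
#
#   import gguf
#   reader = gguf.GGUFReader("mimi-q8_0.gguf")
#   print(len(reader.tensors), "tensors,", len(reader.fields), "metadata fields")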