#!/usr/bin/env python3
"""Export per-language MMS adapter weights from a fairseq checkpoint.

Reads the fairseq ``mms1b_fl102.pt`` checkpoint and, for each of the 102
FLEURS language codes below, extracts that language's adapter state dict,
renames its keys to the Hugging Face wav2vec2 naming scheme via
``load_wav2vec2_layer``, and writes the result twice:

* ``./adapter.<lang>.bin``          — plain ``torch.save`` pickle
* ``./adapter.<lang>.safetensors``  — safetensors with ``{"format": "pt"}``
"""
import torch
from safetensors.torch import save_file as safe_save_file
from transformers.models.wav2vec2.convert_wav2vec2_original_pytorch_checkpoint_to_pytorch import (
    load_wav2vec2_layer,
)

# Language codes whose adapters exist in the checkpoint's ``sd["adapter"]``
# dict (presumably the FLEURS-102 set — verify against the checkpoint).
langs = [
    "afr", "amh", "ara", "asm", "ast", "azj-script_latin", "bel", "ben",
    "bos", "bul", "cat", "ceb", "ces", "ckb", "cmn-script_simplified", "cym",
    "dan", "deu", "ell", "eng", "est", "fas", "fin", "fra", "ful", "gle",
    "glg", "guj", "hau", "heb", "hin", "hrv", "hun", "hye", "ibo", "ind",
    "isl", "ita", "jav", "jpn", "kam", "kan", "kat", "kaz", "kea", "khm",
    "kir", "kor", "lao", "lav", "lin", "lit", "ltz", "lug", "luo", "mal",
    "mar", "mkd", "mlt", "mon", "mri", "mya", "nld", "nob", "npi", "nso",
    "nya", "oci", "orm", "ory", "pan", "pol", "por", "pus", "ron", "rus",
    "slk", "slv", "sna", "snd", "som", "spa", "srp-script_latin", "swe",
    "swh", "tam", "tel", "tgk", "tgl", "tha", "tur", "ukr", "umb",
    "urd-script_arabic", "uzb-script_latin", "vie", "wol", "xho", "yor",
    "yue-script_traditional", "zlm", "zul",
]

# map_location="cpu" so the export also works on a machine without the GPU
# devices the checkpoint was originally saved from.
sd = torch.load("../mms1b_fl102.pt", map_location="cpu")

for lang in langs:
    hf_dict = {}
    fsq_adapters = sd["adapter"][lang]["model"]
    for key, tensor in fsq_adapters.items():
        # load_wav2vec2_layer renames the fairseq key and inserts the tensor
        # into hf_dict in place; hf_dict is what we save, so the return value
        # is not needed (the original bound it to an unused variable).
        load_wav2vec2_layer(key, tensor, hf_dict=hf_dict)
    torch.save(hf_dict, f"./adapter.{lang}.bin")
    safe_save_file(hf_dict, f"./adapter.{lang}.safetensors", metadata={"format": "pt"})