---
license: apache-2.0
tags:
  - moe
  - frankenmoe
  - merge
  - mergekit
  - lazymergekit
  - cognitivecomputations/dolphin-2_6-phi-2
  - rhysjones/phi-2-orange
base_model:
  - cognitivecomputations/dolphin-2_6-phi-2
  - rhysjones/phi-2-orange
---

# PhiMiX-2x2B

PhiMiX-2x2B is a Mixture of Experts (MoE) made with the following models using mergekit:

- [cognitivecomputations/dolphin-2_6-phi-2](https://huggingface.co/cognitivecomputations/dolphin-2_6-phi-2)
- [rhysjones/phi-2-orange](https://huggingface.co/rhysjones/phi-2-orange)

## ©️ Credits

- mlabonne's phixtral for the PhiConfig and inference code.
- The mergekit code, which I tweaked (you can find the PhiConfig here), mainly by adding the Phi config to the `mixtral_moe.py` script from the mixtral branch.

## ⏱️ Benchmarks

| Model | AGIEval | GPT4All | TruthfulQA | Bigbench | Average |
|---|---:|---:|---:|---:|---:|
| **PhiMiX-2x2B** | 33.34 | **71.75** | 49.25 | **37.62** | **47.99** |
| phixtral-4x2_8 | 33.91 | 70.44 | 48.78 | 37.68 | 47.7 |
| phixtral-2x2_8 | 34.1 | 70.44 | 48.78 | 37.82 | 47.78 |
| *phi-2-orange* | 33.37 | 71.33 | 49.87 | 37.3 | 47.97 |
| *dolphin-2_6-phi-2* | 33.12 | 69.85 | 47.39 | 37.2 | 46.89 |

I have used **bold** to highlight this merge in the table, *italics* to highlight the base models used in the merge, and bold in the cells where the merge exceeds the performance of both base models.

## 🧩 Configuration

```yaml
base_model: rhysjones/phi-2-orange
gate_mode: cheap_embed
dtype: float16
experts:
  - source_model: cognitivecomputations/dolphin-2_6-phi-2
    positive_prompts: ["research, logic, math, science"]
  - source_model: rhysjones/phi-2-orange
    positive_prompts: ["programming, reasoning"]
```

## 💻 Usage

```python
!pip install -qU transformers bitsandbytes accelerate

from transformers import AutoTokenizer
import transformers
import torch

model = "paulilioaica/PhiMiX-2x2B"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    trust_remote_code=True,
    model_kwargs={"torch_dtype": torch.float16, "load_in_4bit": True,},
)

prompt="How many continents are there?"
input = f"Instruct: {prompt}\nOutput:"
outputs = pipeline(input, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
print(outputs[0]["generated_text"])
```

```
Instruct: How many continents are there?
Output: There are seven continents: Africa, Antarctica, Asia, Europe, North America, Australia, and South America. The total number of continents on Earth is seven, including Antarctica, which is sometimes considered part of the continent of Antarctica or as its own continent.

It's important to note that the number of continents in popular education and geography is seven, but some sources may include Antarctica as its own continent, while others include it as part of the continent of Antarctica. Regardless of the exact categorization, there are seven continents that collectively make up the Earth's landmass.

The continents can be divided into several subregions, such as islands, archipelagos, and microcontinents, which are smaller land masses surrounded by water. These subregions can be considered part of the continents or their own unique entities, depending on the context.

Each continent has its own unique geography, climate, flora, fauna, and human cultures. The continents are interconnected through various landforms, bodies of water, and global trade routes.

In summary, there are seven continents on Earth, each with its own distinct characteristics and unique contributions to the world's diversity. While the number may vary depending on the categorization of Antarctica, all seven continents together make
```
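On newer `transformers`/`bitsandbytes` stacks, passing `load_in_4bit` through `model_kwargs` is deprecated in favour of an explicit `BitsAndBytesConfig`. The sketch below is an alternative loading path under that assumption; it is not part of the original card.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "paulilioaica/PhiMiX-2x2B"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,  # required: the repo ships custom modeling_phi.py
)

prompt = "Instruct: How many continents are there?\nOutput:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=128, do_sample=True, temperature=0.7)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```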

## ♻️ Replicate this repo

Beware: this will only work with two Phi experts out of the box; you may have to tinker with the tensor naming if you want more.

After all the file modifications and the merge run, replace `config.json` with the one from this repo. After that, add `modeling_phi.py` and `configuration_phi.py` from this repo to your repo.

### Steps

1. Modify `mixtral_moe.py` at `/content/mergekit/mergekit/scripts/mixtral_moe.py` as shown below before merging into your HF repo.

**mixtral_moe.py**

```python

# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
import os
import sys
from typing import Dict, List, Optional, Union

import click
import torch
import tqdm
import transformers
import yaml
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    LlamaForCausalLM,
    MistralConfig,
    MistralForCausalLM,
    MixtralConfig,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

import mergekit.architecture
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.options import add_merge_options

# Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
# Takes the path to a yml config and an output path.
# Config schema is the two classes below.


class Expert(BaseModel):
    source_model: str

    positive_prompts: List[str]
    negative_prompts: Optional[List[str]] = None
    noise_scale: Optional[float] = None

    @property
    def model_ref(self):
        return ModelReference.parse(self.source_model)


class MistralMOEConfig(BaseModel):
    base_model: str
    experts: List[Expert]
    gate_mode: str = "hidden"  # possible values: "hidden", "cheap_embed", "random"
    # "hidden" uses hidden state vectors for the given prompts for each layer
    # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
    # "random" is random
    dtype: Optional[str] = None
    experts_per_token: int = 2


def get_hidden_states(
    model: Union[MistralForCausalLM, LlamaForCausalLM],
    tokenized: transformers.BatchEncoding,
    average: bool = True,
) -> List[torch.Tensor]:
    with torch.no_grad():
        output: CausalLMOutputWithPast = model(
            **tokenized.to(model.device), output_hidden_states=True, return_dict=True
        )
    hidden_states = torch.stack(
        output.hidden_states[:-1]
    )  # (num_layers, batch_size, seq_len, hidden_size)
    if average:
        # use average over sequence
        hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
    else:
        # take last value
        hidden_states = hidden_states[:, :, -1, :]
    return hidden_states.sum(dim=1) / hidden_states.shape[1]


def get_cheap_embedding(
    embed: torch.Tensor,
    tokenized: Dict[str, torch.Tensor],
    num_layers: int,
    vocab_size: int,
) -> torch.Tensor:
    onehot = torch.nn.functional.one_hot(
        tokenized["input_ids"], num_classes=vocab_size
    )  # (batch_size, seq_len, 32000)
    h = onehot.float() @ embed.float()  # (batch_size, seq_len, hidden_size)
    embedded = (
        (h * tokenized["attention_mask"].unsqueeze(-1))
        .sum(dim=1)
        .sum(dim=0, keepdim=True)
    )  # (1, hidden_size)
    res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
        min=1e-8
    )  # (1, hidden_size)
    return res.repeat(num_layers, 1)


def tokenize_prompts(
    prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
):
    return tokenizer(
        [tokenizer.bos_token + p for p in prompts],
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )


def get_gate_params(
    model_ref: ModelReference,
    tokenizer: transformers.PreTrainedTokenizerBase,
    experts: List[Expert],
    mode: str = "hidden",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    lazy_unpickle: bool = False,
    trust_remote_code: bool = False,
    device: str = "auto",
):
    gate_vecs = []
    _do_it = None

    model_cfg = model_ref.config(trust_remote_code=trust_remote_code)

    if mode == "random":
        return torch.randn(
            (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
        )
    elif mode == "cheap_embed":
        embed = LazyTensorLoader(
            model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
        ).get_tensor("transformer.embd.wte.weight")

        def _do_it(tokenized):
            return get_cheap_embedding(
                embed,
                tokenized,
                num_layers=model_cfg.num_hidden_layers,
                vocab_size=model_cfg.vocab_size,
            )

    elif mode in ("hidden", "hidden_avg", "hidden_last"):
        model = AutoModelForCausalLM.from_pretrained(
            model_ref.model.path,
            revision=model_ref.model.revision,
            torch_dtype=torch.bfloat16,
            device_map=device,
            low_cpu_mem_usage=True,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
        )

        def _do_it(tokenized):
            return get_hidden_states(
                model, tokenized=tokenized, average=mode == "hidden_avg"
            )


    gate_vecs = []
    print(experts)
    for expert in tqdm.tqdm(experts, desc="expert prompts"):
        print(_do_it)
        hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
        if expert.negative_prompts:
            hidden_states -= _do_it(
                tokenize_prompts(expert.negative_prompts, tokenizer)
            )

        hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
        gate_vecs.append(hidden_states)
    gate_vecs = torch.stack(gate_vecs, dim=0)  # (num_expert, num_layer, hidden_size)
    return gate_vecs.permute(1, 0, 2)


def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
    degen_indices = []
    num_layers, _num_experts, _hidden_size = gate_vecs.shape
    for idx in range(num_layers):
        c = torch.linalg.cond(gate_vecs[idx, :, :].float())
        if c > threshold:
            degen_indices.append(idx)

    if degen_indices:
        if len(degen_indices) == 1:
            layer_str = f"layer {degen_indices[0]}"
            verb = "has"
        elif len(degen_indices) == 2:
            layer_str = f"layers {' and '.join(map(str, degen_indices))}"
            verb = "have"
        elif len(degen_indices) >= num_layers:
            layer_str = "ALL layers"
            verb = "have"
        else:
            layer_str = (
                "layers "
                + ", ".join(map(str, degen_indices[:-1]))
                + ", and "
                + str(degen_indices[-1])
            )
            verb = "have"

        logging.warning(
            f"{layer_str} {verb} degenerate routing parameters "
            "- your prompts may be too similar."
        )
        logging.warning("One or more experts will be underutilized in your model.")


def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
    if len(config.experts) < 2:
        logging.error("Must include at least two experts.")
        return True

    if config.gate_mode == "random":
        return False  # eh we're good

    def prompt_tup(e: Expert):
        return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))

    # let's just nip this trend in the bud
    p_first = prompt_tup(config.experts[0])
    if all(prompt_tup(e) == p_first for e in config.experts[1:]):
        logging.error(
            "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
        )
        logging.error(
            "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
        )
        return True

    if not allow_all_same:
        if all(
            e.source_model == config.experts[0].source_model for e in config.experts[1:]
        ):
            logging.error(
                "All of your expert models are the same. This will produce "
                "a model that uses more resources but gives the exact same output. "
                "If you plan to train the model after merging, proceed with the "
                "--i-understand-this-is-not-useful-without-training flag."
            )
            return True


def build(
    config: MistralMOEConfig,
    out_path: str,
    merge_options: MergeOptions,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    device: str = "auto",
    allow_all_same: bool = False,
):
    if is_bad_config(config, allow_all_same=allow_all_same):
        sys.exit(1)

    if config.experts_per_token < 1:
        logging.error("Experts per token must be >= 1")
        sys.exit(1)
    if config.experts_per_token > len(config.experts):
        logging.error("Experts per token must be <= number of experts")
        sys.exit(1)

    base_model = ModelReference.parse(config.base_model)
    base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
    if not isinstance(base_cfg, MistralConfig):
        base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
        base_cfg_mistral.sliding_window = None
        base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
        base_cfg = base_cfg_mistral

    out_cfg = MixtralConfig(**base_cfg.to_dict())
    out_cfg.architectures = ["PhiForCausalLM"]
    out_cfg.num_local_experts = len(config.experts)
    out_cfg.num_experts_per_tok = config.experts_per_token
    out_cfg.sliding_window = None
    if config.dtype:
        out_cfg.torch_dtype = config.dtype
    out_cfg.save_pretrained(out_path)

    if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
        logging.warning(
            f"Your model has {out_cfg.num_local_experts} experts, which is "
            "not a power of two. The model will not be usable in llama.cpp."
        )

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(
        [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = LazyTensorLoader(
            model.tensor_index(cache_dir=merge_options.transformers_cache),
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)
    elif base_cfg.torch_dtype:
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    else:
        out_dtype = None

    logging.info("Copying parameters...")
    MISTRAL_INFO = mergekit.architecture.PHI2_INFO
    for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
        tensor = base_loader.get_tensor(tensor_name)
        if not out_dtype:
            # All else has failed, take the first dtype we see
            out_dtype = tensor.dtype
        writer.save_tensor(
            tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
        )
    set_of_seen_tensors = set()

    for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
        for layer_idx in range(base_cfg.num_hidden_layers):
            tensor_name = name_format.format(idx=layer_idx)
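            # Phi-2 MLP weights (mlp.fc1/fc2) are duplicated once per expert under
            # .moe.mlp.<n>: the first expert written gets index 0, the second index 1,
            # which is why this only works out of the box for two experts.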
            if ".mlp.fc" in name_format:
                for moe_index, expert in enumerate(config.experts):
                    if tensor_name in set_of_seen_tensors:
                        expert_name = tensor_name.replace(
                            ".mlp.fc", ".moe.mlp.1.fc"
                        )
                    else:
                        expert_name = tensor_name.replace(
                            ".mlp.fc", ".moe.mlp.0.fc"
                        )
                        set_of_seen_tensors.add(tensor_name)

                    expert_loader = loaders.get(expert.model_ref)
                    tensor = expert_loader.get_tensor(tensor_name)
                    if expert.noise_scale:
                        tensor += torch.randn_like(tensor) * expert.noise_scale
                    writer.save_tensor(
                        expert_name, tensor.to(dtype=out_dtype), clone=True
                    )
                    print(expert_name, tensor_name)
                continue
            writer.save_tensor(
                tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
            )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        base_model.model.path, revision=base_model.model.revision
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.bos_token_id

    logging.info("Getting gate parameters...")
    gate_vecs = get_gate_params(
        base_model,
        tokenizer,
        config.experts,
        mode=config.gate_mode,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        lazy_unpickle=merge_options.lazy_unpickle,
        trust_remote_code=merge_options.trust_remote_code,
        device=device,
    )
    # gate_vecs: (num_layers, num_experts, hidden_size)

    warn_degenerate_gates(gate_vecs)

    for layer_idx in range(base_cfg.num_hidden_layers):
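        # Gate weights are saved under the Phi-2 layer naming:
        # transformer.h.{idx}.moe.gate.weight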
        writer.save_tensor(
            f"transformer.h.{layer_idx}.moe.gate.weight",
            gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
        )
    writer.finalize()

    if merge_options.copy_tokenizer:
        logging.info("Saving tokenizer...")
        tokenizer.save_pretrained(out_path, safe_serialization=True)

    logging.info("Done.")


@click.command("mergekit-moe")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=click.Path())
@click.option(
    "--load-in-4bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 4bit for computing hidden states",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 8bit for computing hidden states",
)
@click.option(
    "--device",
    type=str,
    default="auto",
    help="Device to use to compute embeddings",
    show_default=True,
)
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--i-understand-this-is-not-useful-without-training",
    type=bool,
    default=False,
    is_flag=True,
    help="Really make the questionable model you want.",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    load_in_4bit: bool,
    load_in_8bit: bool,
    device: str,
    merge_options: MergeOptions,
    verbose: bool,
    i_understand_this_is_not_useful_without_training: bool,
):
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    if merge_options.cuda:
        logging.warning(
            '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
        )

    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
    build(
        config,
        out_path=out_path,
        merge_options=merge_options,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        device=device,
        allow_all_same=i_understand_this_is_not_useful_without_training,
    )

    if merge_options.write_model_card:
        # TODO: generate a README.md as well
        with open(
            os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
        ) as fp:
            fp.write(config_source)


if __name__ == "__main__":
    main()
```

2. Modify `architecture.py` at `/content/mergekit/mergekit/architecture.py` (you can take this from the commit linked in the description).

**architecture.py**

```python

# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel
from transformers import PretrainedConfig


class ArchitectureInfo(ABC):
    @abstractmethod
    def pre_weights(self) -> List[str]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self) -> List[str]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weight_formats(self) -> List[str]:
        """Return a list of format strings all weights associated with a layer."""
        ...

    @abstractmethod
    def embed_weights(self) -> List[str]:
        ...

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.num_hidden_layers

    def num_layers_config_key(self) -> str:
        """Key in config that represents number of layers"""
        return "num_hidden_layers"


class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
    name: str

    pre_weight_names: List[str]  # weights applied before first layer
    post_weight_names: List[str]  # weights applied after last layer
    embed_weight_names: List[str]  # weights for embed/lm_head
    layer_prefix_format: str
    layer_weight_suffixes: List[str]
    num_layers_key: Optional[str] = None

    def pre_weights(self) -> List[str]:
        return self.pre_weight_names

    def post_weights(self) -> List[str]:
        return self.post_weight_names

    def embed_weights(self) -> List[str]:
        return self.embed_weight_names

    def layer_weight_formats(self) -> List[str]:
        res = []
        for suffix in self.layer_weight_suffixes:
            res.append(self.layer_prefix_format + "." + suffix)
        return res

    def num_layers_config_key(self) -> str:
        if self.num_layers_key:
            return self.num_layers_key
        return super().num_layers_config_key()

    def num_layers(self, config: PretrainedConfig) -> int:
        return getattr(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[str]:
        num_layers = self.num_layers(config)
        tensor_names = list(self.pre_weights())
        for layer_idx in range(num_layers):
            for f in self.layer_weight_formats():
                tensor_names.append(f.format(idx=layer_idx))
        tensor_names.extend(self.post_weights())
        return tensor_names


LLAMA_INFO = StaticTensorNames(
    name="LlamaForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
        "mlp.gate_proj.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
    ],
)

MISTRAL_INFO = StaticTensorNames(
    name="MistralForCausalLM",
    # lol
    **LLAMA_INFO.model_dump(exclude=["name"]),
)


STABLELM_INFO = StaticTensorNames(
    name="StableLMEpochForCausalLM",
    post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
    layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
    + [
        "input_layernorm.bias",
        "post_attention_layernorm.bias",
    ],
    **LLAMA_INFO.model_dump(
        exclude=["name", "layer_weight_suffixes", "post_weight_names"]
    ),
)

GPT_NEOX_INFO = StaticTensorNames(
    name="GPTNeoXForCausalLM",
    pre_weight_names=["gpt_neox.embed_in.weight"],
    post_weight_names=[
        "gpt_neox.final_layer_norm.bias",
        "gpt_neox.final_layer_norm.weight",
        "embed_out.weight",
    ],
    embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
    layer_prefix_format="gpt_neox.layers.{idx}",
    layer_weight_suffixes=sum(
        (
            [f"{prefix}.weight", f"{prefix}.bias"]
            for prefix in [
                "attention.dense",
                "attention.query_key_value",
                "input_layernorm",
                "mlp.dense_4h_to_h",
                "mlp.dense_h_to_4h",
                "post_attention_layernorm",
            ]
        ),
        start=[],
    )
    + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
)

GPT2_INFO = StaticTensorNames(
    name="GPT2LMHeadModel",
    pre_weight_names=["wte.weight", "wpe.weight"],
    post_weight_names=["ln_f.weight", "ln_f.bias"],
    embed_weight_names=["wte.weight"],
    layer_prefix_format="h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

JAIS_INFO = StaticTensorNames(
    name="JAISLMHeadModel",
    pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
    post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
    embed_weight_names=["transformer.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_fc2.weight",
        "mlp.c_fc2.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

GPT2_SEQCLASS_INFO = StaticTensorNames(
    name="GPT2ForSequenceClassification",
    pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "score.weight",
    ],
    layer_prefix_format="transformer.h.{idx}",
    embed_weight_names=GPT2_INFO.embed_weight_names,
    layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
    num_layers_key=GPT2_INFO.num_layers_key,
)


QWEN_INFO = StaticTensorNames(
    name="QWenLMHeadModel",
    pre_weight_names=["transformer.wte.weight"],
    post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
    embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.bias",
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "ln_1.weight",
        "ln_2.weight",
        "mlp.c_proj.weight",
        "mlp.w1.weight",
        "mlp.w2.weight",
    ],
)

CHATGLM_INFO = StaticTensorNames(
    name="ChatGLMModel",
    pre_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.rotary_pos_emb.inv_freq",
    ],
    post_weight_names=[
        "transformer.encoder.final_layernorm.weight",
        "transformer.output_layer.weight",
    ],
    embed_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.output_layer.weight",
    ],
    layer_prefix_format="transformer.encoder.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "post_attention_layernorm.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.bias",
        "self_attention.query_key_value.weight",
    ],
)

FALCON_INFO = StaticTensorNames(
    name="FalconForCausalLM",
    pre_weight_names=["transformer.word_embeddings.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "lm_head.weight",
    ],
    embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln_attn.bias",
        "ln_attn.weight",
        "ln_mlp.bias",
        "ln_mlp.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.weight",
    ],
)


class PhiTensorNames(ArchitectureInfo):
    architecture_name: str = "MixFormerSequentialForCausalLM"

    def __init__(self, config: PretrainedConfig):
        self.config = config

    def __eq__(self, rhs: "PhiTensorNames"):
        if not isinstance(rhs, PhiTensorNames):
            return False
        return self.num_layers() == rhs.num_layers()

    def pre_weights(self) -> List[str]:
        return ["layers.0.wte.weight"]

    def post_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            f"layers.{fake_layer_idx}.{suffix}"
            for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
        ]

    def embed_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            "layers.0.wte.weight",
            f"layers.{fake_layer_idx}.linear.weight",
            f"layers.{fake_layer_idx}.linear.bias",
        ]

    def layer_weight_formats(self) -> List[str]:
        return [
            ("layers.{idx}." + suffix)
            for suffix in [
                "ln.bias",
                "ln.weight",
                "mixer.Wqkv.bias",
                "mixer.Wqkv.weight",
                "mixer.out_proj.bias",
                "mixer.out_proj.weight",
                "mixer.rotary_emb.inv_freq",
                "mlp.fc1.bias",
                "mlp.fc1.weight",
                "mlp.fc2.bias",
                "mlp.fc2.weight",
            ]
        ]

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.n_layer

    def num_layers_config_key(self) -> str:
        return "n_layer"


PHI2_INFO = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["transformer.embd.wte.weight"],
    post_weight_names=[
        "lm_head.linear.bias",
        "lm_head.linear.weight",
        "lm_head.ln.bias",
        "lm_head.ln.weight",
    ],
    embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln.bias",
        "ln.weight",
        "mixer.out_proj.bias",
        "mixer.out_proj.weight",
        "mixer.Wqkv.bias",
        "mixer.Wqkv.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
    num_layers_key="n_layer",
)


PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=[
        "lm_head.bias",
        "lm_head.weight",
        "model.final_layernorm.bias",
        "model.final_layernorm.weight",
    ],
    embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.bias",
        "input_layernorm.weight",
        "self_attn.dense.bias",
        "self_attn.dense.weight",
        "self_attn.q_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.v_proj.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
)


BAICHUAN_INFO = StaticTensorNames(
    name="BaichuanForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "self_attn.W_pack.weight",
        "self_attn.o_proj.weight",
        "post_attention_layernorm.weight",
        "mlp.gate_proj.weight",
        "mlp.down_proj.weight",
        "mlp.up_proj.weight",
    ],
)


def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
    if len(config.architectures) != 1:
        raise RuntimeError("More than one architecture in config?")

    arch_name = config.architectures[0]
    if arch_name == PhiTensorNames.architecture_name:
        return PhiTensorNames(config)

    if arch_name == PHI2_INFO.name:
        if config.model_type == "phi-msft":
            return PHI2_INFO
        elif config.model_type == "phi":
            return PHI2_INFO_AGAIN_BUT_DIFFERENT

    supported = [
        LLAMA_INFO,
        MISTRAL_INFO,
        GPT_NEOX_INFO,
        QWEN_INFO,
        GPT2_INFO,
        GPT2_SEQCLASS_INFO,
        CHATGLM_INFO,
        STABLELM_INFO,
        JAIS_INFO,
        BAICHUAN_INFO,
        FALCON_INFO,
    ]
    for arch in supported:
        if arch.name == arch_name:
            return arch

    raise RuntimeError(f"Unsupported architecture {arch_name}")
```

3. Replace `config.json` with the one from this repo.
4. Add `modeling_phi.py` and `configuration_phi.py` from this repo to your repo (a sketch of these two steps follows).
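A rough sketch of steps 3 and 4 using `huggingface_hub`, assuming the files keep the names above; `DST_REPO` is a placeholder for your own model repo.

```python
from huggingface_hub import hf_hub_download, upload_file

SRC_REPO = "paulilioaica/PhiMiX-2x2B"
DST_REPO = "your-username/your-moe-repo"  # placeholder: your merged model's repo

# Pull the custom config and modeling code from this repo and push it into yours.
for fname in ["config.json", "modeling_phi.py", "configuration_phi.py"]:
    local_path = hf_hub_download(repo_id=SRC_REPO, filename=fname)
    upload_file(path_or_fileobj=local_path, path_in_repo=fname, repo_id=DST_REPO)
```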