---
license: apache-2.0
tags:
- moe
- frankenmoe
- merge
- mergekit
- lazymergekit
- cognitivecomputations/dolphin-2_6-phi-2
- rhysjones/phi-2-orange
base_model:
- cognitivecomputations/dolphin-2_6-phi-2
- rhysjones/phi-2-orange
---
PhiMiX-2x2B
PhiMiX-2x2B is a Mixture of Experts (MoE) made with the following models using mergekit:
- cognitivecomputations/dolphin-2_6-phi-2
- rhysjones/phi-2-orange
©️ Credits
- mlabonne's phixtral, for the PhiConfig and inference code.
- mergekit, whose code I tweaked (you can find the PhiConfig here), mainly by adding the config to the mixtral_moe.py script from the mixtral branch.
⏱️ Benchmarks
| Model | AGIEval | GPT4All | TruthfulQA | Bigbench | Average |
|---|---|---|---|---|---|
| **PhiMiX-2x2B** | 33.34 | **71.75** | 49.25 | **37.62** | **47.99** |
| phixtral-4x2_8 | 33.91 | 70.44 | 48.78 | 37.68 | 47.7 |
| phixtral-2x2_8 | 34.1 | 70.44 | 48.78 | 37.82 | 47.78 |
| *phi-2-orange* | 33.37 | 71.33 | 49.87 | 37.3 | 47.97 |
| *dolphin-2_6-phi-2* | 33.12 | 69.85 | 47.39 | 37.2 | 46.89 |
I have used bold to highlight this merge, italics to highlight the base models used in the merge, and bold for the cells where the merge exceeds the performance of both base models.
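The Average column appears to be the plain arithmetic mean of the four benchmark scores. The short check below recomputes it from the values in the table; the `scores` dictionary and the `average_score` helper are our own illustration, not part of any evaluation harness.

```python
from statistics import mean

# Scores copied from the benchmark table: AGIEval, GPT4All, TruthfulQA, Bigbench
scores = {
    "PhiMiX-2x2B": [33.34, 71.75, 49.25, 37.62],
    "phixtral-4x2_8": [33.91, 70.44, 48.78, 37.68],
    "phixtral-2x2_8": [34.1, 70.44, 48.78, 37.82],
    "phi-2-orange": [33.37, 71.33, 49.87, 37.3],
    "dolphin-2_6-phi-2": [33.12, 69.85, 47.39, 37.2],
}


def average_score(values):
    """Arithmetic mean of the four benchmark scores."""
    return mean(values)


for name, values in scores.items():
    print(f"{name}: {average_score(values):.2f}")
```

Rounded to two decimals, the recomputed means agree with the reported Average column.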
🧩 Configuration
```yaml
base_model: rhysjones/phi-2-orange
gate_mode: cheap_embed
dtype: float16
experts:
  - source_model: cognitivecomputations/dolphin-2_6-phi-2
    positive_prompts: ["research, logic, math, science"]
  - source_model: rhysjones/phi-2-orange
    positive_prompts: ["programming, reasoning"]
```
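With `gate_mode: cheap_embed`, the merge script builds one gate vector per expert from the token embeddings of that expert's `positive_prompts`. It sums them (which gives the same direction as averaging), L2-normalizes the result, and reuses the same vector for every layer (see `get_cheap_embedding` in the modified mixtral_moe.py). The NumPy sketch below illustrates that idea together with a Mixtral-style top-k routing step on toy data. The random embedding table, the token ids, and the `route` function are illustrative assumptions. They are not the model's real weights or the routing code in modeling_phi.py.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_size, num_layers = 100, 16, 4

# Toy embedding table standing in for transformer.embd.wte.weight (assumption).
embed = rng.normal(size=(vocab_size, hidden_size))


def cheap_gate_vector(token_ids, embed, num_layers):
    """Sum the prompt's token embeddings, L2-normalize, and repeat for every layer."""
    summed = embed[token_ids].sum(axis=0)  # (hidden_size,)
    unit = summed / max(np.linalg.norm(summed), 1e-8)
    return np.tile(unit, (num_layers, 1))  # (num_layers, hidden_size)


# Hypothetical token ids for each expert's positive prompts.
expert_prompt_ids = [np.array([3, 17, 42, 8]), np.array([55, 61, 9])]

# Stack to (num_layers, num_experts, hidden_size), the layout the script writes out.
gates = np.stack(
    [cheap_gate_vector(ids, embed, num_layers) for ids in expert_prompt_ids], axis=1
)


def route(hidden, gate_weight, top_k=2):
    """Mixtral-style routing: logits -> top-k experts -> softmax over the chosen logits."""
    logits = gate_weight @ hidden  # (num_experts,)
    chosen = np.argsort(logits)[::-1][:top_k]
    weights = np.exp(logits[chosen] - logits[chosen].max())
    return chosen, weights / weights.sum()


token_hidden = rng.normal(size=hidden_size)
experts, weights = route(token_hidden, gates[0])
print(experts, weights)
```

With two experts and the default `experts_per_token` of 2, every token is sent to both experts, so the gate only changes how the two outputs are weighted.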
💻 Usage
```python
# Notebook syntax; drop the "!" when running pip from a shell
!pip install -qU transformers bitsandbytes accelerate

import torch
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig

model = "paulilioaica/PhiMiX-2x2B"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,  # the MoE modeling code ships with the repo
    model_kwargs={
        "torch_dtype": torch.float16,
        # load the weights in 4-bit via bitsandbytes
        "quantization_config": BitsAndBytesConfig(load_in_4bit=True),
    },
)

prompt = "How many continents are there?"
formatted_prompt = f"Instruct: {prompt}\nOutput:"
outputs = pipeline(
    formatted_prompt,
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
)
print(outputs[0]["generated_text"])
```

Sample output (generation stops at max_new_tokens, so the last sentence is cut off):

```text
Instruct: How many continents are there?
Output: There are seven continents: Africa, Antarctica, Asia, Europe, North America, Australia, and South America. The total number of continents on Earth is seven, including Antarctica, which is sometimes considered part of the continent of Antarctica or as its own continent.
It's important to note that the number of continents in popular education and geography is seven, but some sources may include Antarctica as its own continent, while others include it as part of the continent of Antarctica. Regardless of the exact categorization, there are seven continents that collectively make up the Earth's landmass.
The continents can be divided into several subregions, such as islands, archipelagos, and microcontinents, which are smaller land masses surrounded by water. These subregions can be considered part of the continents or their own unique entities, depending on the context.
Each continent has its own unique geography, climate, flora, fauna, and human cultures. The continents are interconnected through various landforms, bodies of water, and global trade routes.
In summary, there are seven continents on Earth, each with its own distinct characteristics and unique contributions to the world's diversity. While the number may vary depending on the categorization of Antarctica, all seven continents together make
```
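The usage example wraps the question in the `Instruct: ...\nOutput:` format. By default, the text-generation pipeline's `generated_text` also contains the prompt. The pair of helpers below (the names are our own and hypothetical) keeps the formatting in one place and strips the echoed prompt from the result.

```python
def build_prompt(question: str) -> str:
    """Wrap a question in the Instruct/Output format used in the usage example."""
    return f"Instruct: {question}\nOutput:"


def extract_answer(generated_text: str, prompt: str) -> str:
    """Drop the echoed prompt (pipelines return it by default) and trim whitespace."""
    if generated_text.startswith(prompt):
        generated_text = generated_text[len(prompt):]
    return generated_text.strip()


prompt = build_prompt("How many continents are there?")
sample = prompt + " There are seven continents."
print(extract_answer(sample, prompt))  # There are seven continents.
```

Alternatively, passing `return_full_text=False` to the pipeline call makes it return only the newly generated text.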
♻️ Replicate this repo
Beware: this only works with two Phi models. For more experts, you may have to adjust how the expert tensors are named.
After all the file modifications are made and the merge has run, replace config.json in the output with the one from this repo.
After that, add modeling_phi.py and configuration_phi.py from this repo to your repo.
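The post-merge file copying can be scripted. The sketch below assumes that the merge output was written to one local directory and that the Phi files from this repo were downloaded into another. Both directory paths and the exact file names in `REPO_FILES` are assumptions, so adjust them to match the repo's file listing.

```python
import shutil
from pathlib import Path

# Files to take from this repo (names assumed; adjust to match the repo listing).
REPO_FILES = ["config.json", "modeling_phi.py", "configuration_phi.py"]


def patch_merged_model(repo_dir: str, merged_dir: str) -> list:
    """Copy the Phi config and remote-code files over the mergekit output."""
    src, dst = Path(repo_dir), Path(merged_dir)
    copied = []
    for name in REPO_FILES:
        source = src / name
        if not source.is_file():
            raise FileNotFoundError(f"missing {source}")
        # Overwrites the config.json that mergekit wrote, if present.
        shutil.copy2(source, dst / name)
        copied.append(name)
    return copied
```

For example, `patch_merged_model("phimix-files", "merge")` would copy the three files into `merge/`. Both directory names here are hypothetical.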
Steps
- Modify mixtral_moe.py at /content/mergekit/mergekit/scripts/mixtral_moe.py so that it matches the version below:
mixtral_moe.py
```python
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import logging
import os
import sys
from typing import Dict, List, Optional, Union

import click
import torch
import tqdm
import transformers
import yaml
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    LlamaForCausalLM,
    MistralConfig,
    MistralForCausalLM,
    MixtralConfig,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

import mergekit.architecture
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.options import add_merge_options

# Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
# Takes the path to a yml config and an output path.
# Config schema is the two classes below.
# This copy is adapted to Phi-2: tensor names come from mergekit.architecture.PHI2_INFO.


class Expert(BaseModel):
    source_model: str

    positive_prompts: List[str]
    negative_prompts: Optional[List[str]] = None
    noise_scale: Optional[float] = None

    @property
    def model_ref(self):
        return ModelReference.parse(self.source_model)


class MistralMOEConfig(BaseModel):
    base_model: str
    experts: List[Expert]
    gate_mode: str = "hidden"  # possible values: "hidden", "cheap_embed", "random"
    # "hidden" uses hidden state vectors for the given prompts for each layer
    # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
    # "random" is random
    dtype: Optional[str] = None
    experts_per_token: int = 2


def get_hidden_states(
    model: Union[MistralForCausalLM, LlamaForCausalLM],
    tokenized: transformers.BatchEncoding,
    average: bool = True,
) -> torch.Tensor:
    with torch.no_grad():
        output: CausalLMOutputWithPast = model(
            **tokenized.to(model.device), output_hidden_states=True, return_dict=True
        )
        hidden_states = torch.stack(
            output.hidden_states[:-1]
        )  # (num_layers, batch_size, seq_len, hidden_size)
        if average:
            # use average over sequence
            hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
        else:
            # take last value
            hidden_states = hidden_states[:, :, -1, :]
        return hidden_states.sum(dim=1) / hidden_states.shape[1]


def get_cheap_embedding(
    embed: torch.Tensor,
    tokenized: Dict[str, torch.Tensor],
    num_layers: int,
    vocab_size: int,
) -> torch.Tensor:
    onehot = torch.nn.functional.one_hot(
        tokenized["input_ids"], num_classes=vocab_size
    )  # (batch_size, seq_len, vocab_size)
    h = onehot.float() @ embed.float()  # (batch_size, seq_len, hidden_size)
    embedded = (
        (h * tokenized["attention_mask"].unsqueeze(-1))
        .sum(dim=1)
        .sum(dim=0, keepdim=True)
    )  # (1, hidden_size)
    res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
        min=1e-8
    )  # (1, hidden_size)
    return res.repeat(num_layers, 1)


def tokenize_prompts(
    prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
):
    return tokenizer(
        [tokenizer.bos_token + p for p in prompts],
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )


def get_gate_params(
    model_ref: ModelReference,
    tokenizer: transformers.PreTrainedTokenizerBase,
    experts: List[Expert],
    mode: str = "hidden",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    lazy_unpickle: bool = False,
    trust_remote_code: bool = False,
    device: str = "auto",
):
    model_cfg = model_ref.config(trust_remote_code=trust_remote_code)

    if mode == "random":
        return torch.randn(
            (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
        )
    elif mode == "cheap_embed":
        # Phi-2 (phi-msft) name of the token embedding matrix
        embed = LazyTensorLoader(
            model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
        ).get_tensor("transformer.embd.wte.weight")

        def _do_it(tokenized):
            return get_cheap_embedding(
                embed,
                tokenized,
                num_layers=model_cfg.num_hidden_layers,
                vocab_size=model_cfg.vocab_size,
            )

    elif mode in ("hidden", "hidden_avg", "hidden_last"):
        model = AutoModelForCausalLM.from_pretrained(
            model_ref.model.path,
            revision=model_ref.model.revision,
            torch_dtype=torch.bfloat16,
            device_map=device,
            low_cpu_mem_usage=True,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
        )

        def _do_it(tokenized):
            return get_hidden_states(
                model, tokenized=tokenized, average=mode == "hidden_avg"
            )

    else:
        raise ValueError(f"Unknown gate mode: {mode}")

    gate_vecs = []
    for expert in tqdm.tqdm(experts, desc="expert prompts"):
        hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
        if expert.negative_prompts:
            hidden_states -= _do_it(
                tokenize_prompts(expert.negative_prompts, tokenizer)
            )

        hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
        gate_vecs.append(hidden_states)
    gate_vecs = torch.stack(gate_vecs, dim=0)  # (num_expert, num_layer, hidden_size)
    return gate_vecs.permute(1, 0, 2)


def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
    degen_indices = []
    num_layers, _num_experts, _hidden_size = gate_vecs.shape
    for idx in range(num_layers):
        c = torch.linalg.cond(gate_vecs[idx, :, :].float())
        if c > threshold:
            degen_indices.append(idx)

    if degen_indices:
        if len(degen_indices) == 1:
            layer_str = f"layer {degen_indices[0]}"
            verb = "has"
        elif len(degen_indices) == 2:
            layer_str = f"layers {' and '.join(map(str, degen_indices))}"
            verb = "have"
        elif len(degen_indices) >= num_layers:
            layer_str = "ALL layers"
            verb = "have"
        else:
            layer_str = (
                "layers "
                + ", ".join(map(str, degen_indices[:-1]))
                + ", and "
                + str(degen_indices[-1])
            )
            verb = "have"

        logging.warning(
            f"{layer_str} {verb} degenerate routing parameters "
            "- your prompts may be too similar."
        )
        logging.warning("One or more experts will be underutilized in your model.")


def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
    if len(config.experts) < 2:
        logging.error("Must include at least two experts.")
        return True

    if config.gate_mode == "random":
        return False  # eh we're good

    def prompt_tup(e: Expert):
        return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))

    # let's just nip this trend in the bud
    p_first = prompt_tup(config.experts[0])
    if all(prompt_tup(e) == p_first for e in config.experts[1:]):
        logging.error(
            "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
        )
        logging.error(
            "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
        )
        return True

    if not allow_all_same:
        if all(
            e.source_model == config.experts[0].source_model for e in config.experts[1:]
        ):
            logging.error(
                "All of your expert models are the same. This will produce "
                "a model that uses more resources but gives the exact same output. "
                "If you plan to train the model after merging, proceed with the "
                "--i-understand-this-is-not-useful-without-training flag."
            )
            return True

    return False


def build(
    config: MistralMOEConfig,
    out_path: str,
    merge_options: MergeOptions,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    device: str = "auto",
    allow_all_same: bool = False,
):
    if is_bad_config(config, allow_all_same=allow_all_same):
        sys.exit(1)

    if config.experts_per_token < 1:
        logging.error("Experts per token must be >= 1")
        sys.exit(1)
    if config.experts_per_token > len(config.experts):
        logging.error("Experts per token must be <= number of experts")
        sys.exit(1)

    base_model = ModelReference.parse(config.base_model)
    base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
    if not isinstance(base_cfg, MistralConfig):
        base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
        base_cfg_mistral.sliding_window = None
        base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
        base_cfg = base_cfg_mistral

    out_cfg = MixtralConfig(**base_cfg.to_dict())
    out_cfg.architectures = ["PhiForCausalLM"]
    out_cfg.num_local_experts = len(config.experts)
    out_cfg.num_experts_per_tok = config.experts_per_token
    out_cfg.sliding_window = None
    if config.dtype:
        out_cfg.torch_dtype = config.dtype
    out_cfg.save_pretrained(out_path)

    if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
        logging.warning(
            f"Your model has {out_cfg.num_local_experts} experts, which is "
            "not a power of two. The model will not be usable in llama.cpp."
        )

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(
        [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = LazyTensorLoader(
            model.tensor_index(cache_dir=merge_options.transformers_cache),
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)
    elif base_cfg.torch_dtype:
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    else:
        out_dtype = None

    logging.info("Copying parameters...")
    # Phi-2 tensor names (kept under the MISTRAL_INFO name used by the upstream script)
    MISTRAL_INFO = mergekit.architecture.PHI2_INFO
    for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
        tensor = base_loader.get_tensor(tensor_name)
        if not out_dtype:
            # All else has failed, take the first dtype we see
            out_dtype = tensor.dtype
        writer.save_tensor(
            tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
        )

    for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
        for layer_idx in range(base_cfg.num_hidden_layers):
            tensor_name = name_format.format(idx=layer_idx)

            if ".mlp.fc" in name_format:
                # MLP weights come from each expert:
                # "transformer.h.N.mlp.fc1..." -> "transformer.h.N.moe.mlp.<expert>.fc1..."
                for moe_index, expert in enumerate(config.experts):
                    expert_name = tensor_name.replace(
                        ".mlp.fc", f".moe.mlp.{moe_index}.fc"
                    )
                    expert_loader = loaders.get(expert.model_ref)
                    tensor = expert_loader.get_tensor(tensor_name)
                    if expert.noise_scale:
                        tensor += torch.randn_like(tensor) * expert.noise_scale
                    writer.save_tensor(
                        expert_name, tensor.to(dtype=out_dtype), clone=True
                    )
                continue

            # Everything else (attention, layer norms) is copied from the base model
            writer.save_tensor(
                tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
            )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        base_model.model.path, revision=base_model.model.revision
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.bos_token_id

    logging.info("Getting gate parameters...")
    gate_vecs = get_gate_params(
        base_model,
        tokenizer,
        config.experts,
        mode=config.gate_mode,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        lazy_unpickle=merge_options.lazy_unpickle,
        trust_remote_code=merge_options.trust_remote_code,
        device=device,
    )
    # gate_vecs: (num_layers, num_experts, hidden_size)

    warn_degenerate_gates(gate_vecs)

    for layer_idx in range(base_cfg.num_hidden_layers):
        writer.save_tensor(
            f"transformer.h.{layer_idx}.moe.gate.weight",
            gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
        )
    writer.finalize()

    if merge_options.copy_tokenizer:
        logging.info("Saving tokenizer...")
        tokenizer.save_pretrained(out_path, safe_serialization=True)

    logging.info("Done.")


@click.command("mergekit-moe")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=click.Path())
@click.option(
    "--load-in-4bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 4bit for computing hidden states",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 8bit for computing hidden states",
)
@click.option(
    "--device",
    type=str,
    default="auto",
    help="Device to use to compute embeddings",
    show_default=True,
)
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--i-understand-this-is-not-useful-without-training",
    type=bool,
    default=False,
    is_flag=True,
    help="Really make the questionable model you want.",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    load_in_4bit: bool,
    load_in_8bit: bool,
    device: str,
    merge_options: MergeOptions,
    verbose: bool,
    i_understand_this_is_not_useful_without_training: bool,
):
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    if merge_options.cuda:
        logging.warning(
            '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
        )

    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
    build(
        config,
        out_path=out_path,
        merge_options=merge_options,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        device=device,
        allow_all_same=i_understand_this_is_not_useful_without_training,
    )

    if merge_options.write_model_card:
        # TODO: generate a README.md as well
        with open(
            os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
        ) as fp:
            fp.write(config_source)


if __name__ == "__main__":
    main()
```
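In `build`, each expert's MLP tensors are renamed by rewriting `.mlp.fc` to `.moe.mlp.<i>.fc`, where `<i>` is the expert's slot. One router weight per layer is written as `transformer.h.<layer>.moe.gate.weight`. The hypothetical helpers below isolate that naming and key it on the expert's position in the `experts` list, which is the part you would need to get right when going beyond two experts. They are not part of mergekit.

```python
def expert_tensor_name(tensor_name: str, expert_index: int) -> str:
    """Map a dense Phi MLP tensor name to its MoE expert slot (hypothetical helper)."""
    if ".mlp.fc" not in tensor_name:
        raise ValueError(f"not an MLP tensor: {tensor_name}")
    return tensor_name.replace(".mlp.fc", f".moe.mlp.{expert_index}.fc")


def gate_tensor_name(layer_idx: int) -> str:
    """Router weight name written once per layer."""
    return f"transformer.h.{layer_idx}.moe.gate.weight"


experts = ["cognitivecomputations/dolphin-2_6-phi-2", "rhysjones/phi-2-orange"]
for i, _ in enumerate(experts):
    print(expert_tensor_name("transformer.h.0.mlp.fc1.weight", i))
# transformer.h.0.moe.mlp.0.fc1.weight
# transformer.h.0.moe.mlp.1.fc1.weight
```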
- Modify architecture.py at /content/mergekit/mergekit/architecture.py (you can take this from the commit linked in the description):
architecture.py
```python
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from abc import ABC, abstractmethod
from typing import List, Optional

from pydantic import BaseModel
from transformers import PretrainedConfig


class ArchitectureInfo(ABC):
    @abstractmethod
    def pre_weights(self) -> List[str]:
        """Return a list of all weights preceding the first layer."""
        ...

    @abstractmethod
    def post_weights(self) -> List[str]:
        """Return a list of all weights following the final layer."""
        ...

    @abstractmethod
    def layer_weight_formats(self) -> List[str]:
        """Return a list of format strings for all weights associated with a layer."""
        ...

    @abstractmethod
    def embed_weights(self) -> List[str]:
        ...

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.num_hidden_layers

    def num_layers_config_key(self) -> str:
        """Key in config that represents number of layers"""
        return "num_hidden_layers"


class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
    name: str

    pre_weight_names: List[str]  # weights applied before first layer
    post_weight_names: List[str]  # weights applied after last layer
    embed_weight_names: List[str]  # weights for embed/lm_head
    layer_prefix_format: str
    layer_weight_suffixes: List[str]
    num_layers_key: Optional[str] = None

    def pre_weights(self) -> List[str]:
        return self.pre_weight_names

    def post_weights(self) -> List[str]:
        return self.post_weight_names

    def embed_weights(self) -> List[str]:
        return self.embed_weight_names

    def layer_weight_formats(self) -> List[str]:
        res = []
        for suffix in self.layer_weight_suffixes:
            res.append(self.layer_prefix_format + "." + suffix)
        return res

    def num_layers_config_key(self) -> str:
        if self.num_layers_key:
            return self.num_layers_key
        return super().num_layers_config_key()

    def num_layers(self, config: PretrainedConfig) -> int:
        return getattr(config, self.num_layers_config_key())

    def all_weights(self, config: PretrainedConfig) -> List[str]:
        num_layers = self.num_layers(config)
        tensor_names = list(self.pre_weights())
        for layer_idx in range(num_layers):
            for f in self.layer_weight_formats():
                tensor_names.append(f.format(idx=layer_idx))
        tensor_names.extend(self.post_weights())
        return tensor_names


LLAMA_INFO = StaticTensorNames(
    name="LlamaForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
        "mlp.gate_proj.weight",
        "post_attention_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
    ],
)

MISTRAL_INFO = StaticTensorNames(
    name="MistralForCausalLM",
    # lol
    **LLAMA_INFO.model_dump(exclude=["name"]),
)


STABLELM_INFO = StaticTensorNames(
    name="StableLMEpochForCausalLM",
    post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
    layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
    + [
        "input_layernorm.bias",
        "post_attention_layernorm.bias",
    ],
    **LLAMA_INFO.model_dump(
        exclude=["name", "layer_weight_suffixes", "post_weight_names"]
    ),
)

GPT_NEOX_INFO = StaticTensorNames(
    name="GPTNeoXForCausalLM",
    pre_weight_names=["gpt_neox.embed_in.weight"],
    post_weight_names=[
        "gpt_neox.final_layer_norm.bias",
        "gpt_neox.final_layer_norm.weight",
        "embed_out.weight",
    ],
    embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
    layer_prefix_format="gpt_neox.layers.{idx}",
    layer_weight_suffixes=sum(
        (
            [f"{prefix}.weight", f"{prefix}.bias"]
            for prefix in [
                "attention.dense",
                "attention.query_key_value",
                "input_layernorm",
                "mlp.dense_4h_to_h",
                "mlp.dense_h_to_4h",
                "post_attention_layernorm",
            ]
        ),
        start=[],
    )
    + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
)

GPT2_INFO = StaticTensorNames(
    name="GPT2LMHeadModel",
    pre_weight_names=["wte.weight", "wpe.weight"],
    post_weight_names=["ln_f.weight", "ln_f.bias"],
    embed_weight_names=["wte.weight"],
    layer_prefix_format="h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

JAIS_INFO = StaticTensorNames(
    name="JAISLMHeadModel",
    pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
    post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
    embed_weight_names=["transformer.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.weight",
        "attn.c_attn.bias",
        "attn.c_proj.weight",
        "attn.c_proj.bias",
        "ln_1.weight",
        "ln_1.bias",
        "ln_2.weight",
        "ln_2.bias",
        "mlp.c_fc.weight",
        "mlp.c_fc.bias",
        "mlp.c_fc2.weight",
        "mlp.c_fc2.bias",
        "mlp.c_proj.weight",
        "mlp.c_proj.bias",
    ],
    num_layers_key="n_layer",
)

GPT2_SEQCLASS_INFO = StaticTensorNames(
    name="GPT2ForSequenceClassification",
    pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "score.weight",
    ],
    layer_prefix_format="transformer.h.{idx}",
    embed_weight_names=GPT2_INFO.embed_weight_names,
    layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
    num_layers_key=GPT2_INFO.num_layers_key,
)


QWEN_INFO = StaticTensorNames(
    name="QWenLMHeadModel",
    pre_weight_names=["transformer.wte.weight"],
    post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
    embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "attn.c_attn.bias",
        "attn.c_attn.weight",
        "attn.c_proj.weight",
        "ln_1.weight",
        "ln_2.weight",
        "mlp.c_proj.weight",
        "mlp.w1.weight",
        "mlp.w2.weight",
    ],
)

CHATGLM_INFO = StaticTensorNames(
    name="ChatGLMModel",
    pre_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.rotary_pos_emb.inv_freq",
    ],
    post_weight_names=[
        "transformer.encoder.final_layernorm.weight",
        "transformer.output_layer.weight",
    ],
    embed_weight_names=[
        "transformer.embedding.word_embeddings.weight",
        "transformer.output_layer.weight",
    ],
    layer_prefix_format="transformer.encoder.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "post_attention_layernorm.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.bias",
        "self_attention.query_key_value.weight",
    ],
)

FALCON_INFO = StaticTensorNames(
    name="FalconForCausalLM",
    pre_weight_names=["transformer.word_embeddings.weight"],
    post_weight_names=[
        "transformer.ln_f.weight",
        "transformer.ln_f.bias",
        "lm_head.weight",
    ],
    embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln_attn.bias",
        "ln_attn.weight",
        "ln_mlp.bias",
        "ln_mlp.weight",
        "mlp.dense_4h_to_h.weight",
        "mlp.dense_h_to_4h.weight",
        "self_attention.dense.weight",
        "self_attention.query_key_value.weight",
    ],
)


class PhiTensorNames(ArchitectureInfo):
    architecture_name: str = "MixFormerSequentialForCausalLM"

    def __init__(self, config: PretrainedConfig):
        self.config = config

    def __eq__(self, rhs: "PhiTensorNames"):
        if not isinstance(rhs, PhiTensorNames):
            return False
        return self.num_layers(self.config) == rhs.num_layers(rhs.config)

    def pre_weights(self) -> List[str]:
        return ["layers.0.wte.weight"]

    def post_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            f"layers.{fake_layer_idx}.{suffix}"
            for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
        ]

    def embed_weights(self) -> List[str]:
        fake_layer_idx = self.config.n_layer + 1
        return [
            "layers.0.wte.weight",
            f"layers.{fake_layer_idx}.linear.weight",
            f"layers.{fake_layer_idx}.linear.bias",
        ]

    def layer_weight_formats(self) -> List[str]:
        return [
            ("layers.{idx}." + suffix)
            for suffix in [
                "ln.bias",
                "ln.weight",
                "mixer.Wqkv.bias",
                "mixer.Wqkv.weight",
                "mixer.out_proj.bias",
                "mixer.out_proj.weight",
                "mixer.rotary_emb.inv_freq",
                "mlp.fc1.bias",
                "mlp.fc1.weight",
                "mlp.fc2.bias",
                "mlp.fc2.weight",
            ]
        ]

    def num_layers(self, config: PretrainedConfig) -> int:
        return config.n_layer

    def num_layers_config_key(self) -> str:
        return "n_layer"


# Phi-2 checkpoints with model_type "phi-msft" (the naming used by this merge)
PHI2_INFO = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["transformer.embd.wte.weight"],
    post_weight_names=[
        "lm_head.linear.bias",
        "lm_head.linear.weight",
        "lm_head.ln.bias",
        "lm_head.ln.weight",
    ],
    embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=[
        "ln.bias",
        "ln.weight",
        "mixer.out_proj.bias",
        "mixer.out_proj.weight",
        "mixer.Wqkv.bias",
        "mixer.Wqkv.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
    num_layers_key="n_layer",
)


# Phi-2 checkpoints with model_type "phi" (transformers-native naming)
PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
    name="PhiForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=[
        "lm_head.bias",
        "lm_head.weight",
        "model.final_layernorm.bias",
        "model.final_layernorm.weight",
    ],
    embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.bias",
        "input_layernorm.weight",
        "self_attn.dense.bias",
        "self_attn.dense.weight",
        "self_attn.q_proj.bias",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.v_proj.weight",
        "mlp.fc1.bias",
        "mlp.fc1.weight",
        "mlp.fc2.bias",
        "mlp.fc2.weight",
    ],
)


BAICHUAN_INFO = StaticTensorNames(
    name="BaichuanForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=[
        "input_layernorm.weight",
        "self_attn.W_pack.weight",
        "self_attn.o_proj.weight",
        "post_attention_layernorm.weight",
        "mlp.gate_proj.weight",
        "mlp.down_proj.weight",
        "mlp.up_proj.weight",
    ],
)


def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo:
    if len(config.architectures) != 1:
        raise RuntimeError("More than one architecture in config?")

    arch_name = config.architectures[0]
    if arch_name == PhiTensorNames.architecture_name:
        return PhiTensorNames(config)

    if arch_name == PHI2_INFO.name:
        if config.model_type == "phi-msft":
            return PHI2_INFO
        elif config.model_type == "phi":
            return PHI2_INFO_AGAIN_BUT_DIFFERENT

    supported = [
        LLAMA_INFO,
        MISTRAL_INFO,
        GPT_NEOX_INFO,
        QWEN_INFO,
        GPT2_INFO,
        GPT2_SEQCLASS_INFO,
        CHATGLM_INFO,
        STABLELM_INFO,
        JAIS_INFO,
        BAICHUAN_INFO,
        FALCON_INFO,
    ]
    for arch in supported:
        if arch.name == arch_name:
            return arch

    raise RuntimeError(f"Unsupported architecture {arch_name}")
```
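`StaticTensorNames` builds per-layer tensor names by joining a layer prefix to a list of suffixes (`layer_weight_formats`) and then filling in `{idx}` for every layer (`all_weights`). The standalone sketch below mirrors that expansion for the `PHI2_INFO` naming without importing mergekit. It uses a plain dataclass instead of the pydantic model and only a subset of the layer suffixes, so treat it as an illustration rather than a drop-in replacement.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class TensorNames:
    pre_weight_names: List[str]
    post_weight_names: List[str]
    layer_prefix_format: str
    layer_weight_suffixes: List[str]

    def layer_weight_formats(self) -> List[str]:
        # e.g. "transformer.h.{idx}" + "." + "mlp.fc1.weight"; {idx} stays unfilled here
        return [f"{self.layer_prefix_format}.{s}" for s in self.layer_weight_suffixes]

    def all_weights(self, num_layers: int) -> List[str]:
        names = list(self.pre_weight_names)
        for idx in range(num_layers):
            names.extend(f.format(idx=idx) for f in self.layer_weight_formats())
        names.extend(self.post_weight_names)
        return names


phi2 = TensorNames(
    pre_weight_names=["transformer.embd.wte.weight"],
    post_weight_names=[
        "lm_head.linear.bias",
        "lm_head.linear.weight",
        "lm_head.ln.bias",
        "lm_head.ln.weight",
    ],
    layer_prefix_format="transformer.h.{idx}",
    layer_weight_suffixes=["ln.weight", "mlp.fc1.weight", "mlp.fc2.weight"],
)
names = phi2.all_weights(num_layers=2)
print(names[:4])
```

The MoE script relies on this expansion. Any name it produces that contains `.mlp.fc` is routed to the experts, and all other names are copied from the base model.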
- Replace config.json with the one from this repo.
- Add modeling_phi.py and configuration_phi.py from this repo to your repo.