paulilioaica committed
Commit ae3dbb5
1 Parent(s): e298d48

Update README.md

Files changed (1)
  1. README.md +2 -932
README.md CHANGED
@@ -109,947 +109,17 @@ Steps
109
 
110
 1. Modify `mixtral_moe.py` from `/content/mergekit/mergekit/scripts/mixtral_moe.py` to your HF repo
111
 
112
- ***mixtral_moe.py***
+ [***mixtral_moe.py***](https://github.com/paulilioaica/Phi-MOE/blob/main/mixtral_moe.py)
113
 
114
  ```
115
- # Copyright (C) 2024 Charles O. Goddard
116
- #
117
- # This software is free software: you can redistribute it and/or
118
- # modify it under the terms of the GNU Lesser General Public License as
119
- # published by the Free Software Foundation, either version 3 of the
120
- # License, or (at your option) any later version.
121
- #
122
- # This software is distributed in the hope that it will be useful, but
123
- # WITHOUT ANY WARRANTY; without even the implied warranty of
124
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
125
- # Lesser General Public License for more details.
126
- #
127
- # You should have received a copy of the GNU Lesser General Public License
128
- # along with this program. If not, see http://www.gnu.org/licenses/.
129
 
130
- import logging
131
- import os
132
- import sys
133
- from typing import Dict, List, Optional, Union
134
-
135
- import click
136
- import torch
137
- import tqdm
138
- import transformers
139
- import yaml
140
- from pydantic import BaseModel
141
- from transformers import (
142
- AutoModelForCausalLM,
143
- LlamaForCausalLM,
144
- MistralConfig,
145
- MistralForCausalLM,
146
- MixtralConfig,
147
- )
148
- from transformers.modeling_outputs import CausalLMOutputWithPast
149
-
150
- import mergekit.architecture
151
- from mergekit.common import ModelReference, dtype_from_name
152
- from mergekit.io import LazyTensorLoader, TensorWriter
153
- from mergekit.merge import MergeOptions
154
- from mergekit.options import add_merge_options
155
-
156
- # Create a Mixtral MoE from a set of equally-sized Mistral (or Llama) models.
157
- # Takes the path to a yml config and an output path.
158
- # Config schema is the two classes below.
159
-
160
-
161
- class Expert(BaseModel):
162
- source_model: str
163
-
164
- positive_prompts: List[str]
165
- negative_prompts: Optional[List[str]] = None
166
- noise_scale: Optional[float] = None
167
-
168
- @property
169
- def model_ref(self):
170
- return ModelReference.parse(self.source_model)
171
-
172
-
173
- class MistralMOEConfig(BaseModel):
174
- base_model: str
175
- experts: List[Expert]
176
- gate_mode: str = "hidden" # possible values: "hidden", "cheap_embed", "random"
177
- # "hidden" uses hidden state vectors for the given prompts for each layer
178
- # "cheap_embed" uses the average of token embeddings for the prompts, same for each layer
179
- # "random" is random
180
- dtype: Optional[str] = None
181
- experts_per_token: int = 2
182
-
183
-
184
- def get_hidden_states(
185
- model: Union[MistralForCausalLM, LlamaForCausalLM],
186
- tokenized: transformers.BatchEncoding,
187
- average: bool = True,
188
- ) -> List[torch.Tensor]:
189
- with torch.no_grad():
190
- output: CausalLMOutputWithPast = model(
191
- **tokenized.to(model.device), output_hidden_states=True, return_dict=True
192
- )
193
- hidden_states = torch.stack(
194
- output.hidden_states[:-1]
195
- ) # (num_layers, batch_size, seq_len, hidden_size)
196
- if average:
197
- # use average over sequence
198
- hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
199
- else:
200
- # take last value
201
- hidden_states = hidden_states[:, :, -1, :]
202
- return hidden_states.sum(dim=1) / hidden_states.shape[1]
203
-
204
-
205
- def get_cheap_embedding(
206
- embed: torch.Tensor,
207
- tokenized: Dict[str, torch.Tensor],
208
- num_layers: int,
209
- vocab_size: int,
210
- ) -> torch.Tensor:
211
- onehot = torch.nn.functional.one_hot(
212
- tokenized["input_ids"], num_classes=vocab_size
213
- ) # (batch_size, seq_len, 32000)
214
- h = onehot.float() @ embed.float() # (batch_size, seq_len, hidden_size)
215
- embedded = (
216
- (h * tokenized["attention_mask"].unsqueeze(-1))
217
- .sum(dim=1)
218
- .sum(dim=0, keepdim=True)
219
- ) # (1, hidden_size)
220
- res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
221
- min=1e-8
222
- ) # (1, hidden_size)
223
- return res.repeat(num_layers, 1)
224
-
225
-
226
- def tokenize_prompts(
227
- prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
228
- ):
229
- return tokenizer(
230
- [tokenizer.bos_token + p for p in prompts],
231
- return_tensors="pt",
232
- padding=True,
233
- add_special_tokens=False,
234
- )
235
-
236
-
237
- def get_gate_params(
238
- model_ref: ModelReference,
239
- tokenizer: transformers.PreTrainedTokenizerBase,
240
- experts: List[Expert],
241
- mode: str = "hidden",
242
- load_in_4bit: bool = False,
243
- load_in_8bit: bool = False,
244
- lazy_unpickle: bool = False,
245
- trust_remote_code: bool = False,
246
- device: str = "auto",
247
- ):
248
- gate_vecs = []
249
- _do_it = None
250
-
251
- model_cfg = model_ref.config(trust_remote_code=trust_remote_code)
252
-
253
- if mode == "random":
254
- return torch.randn(
255
- (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
256
- )
257
- elif mode == "cheap_embed":
258
- embed = LazyTensorLoader(
259
- model_ref.tensor_index(), lazy_unpickle=lazy_unpickle
260
- ).get_tensor("transformer.embd.wte.weight")
261
-
262
- def _do_it(tokenized):
263
- return get_cheap_embedding(
264
- embed,
265
- tokenized,
266
- num_layers=model_cfg.num_hidden_layers,
267
- vocab_size=model_cfg.vocab_size,
268
- )
269
-
270
- elif mode in ("hidden", "hidden_avg", "hidden_last"):
271
- model = AutoModelForCausalLM.from_pretrained(
272
- model_ref.model.path,
273
- revision=model_ref.model.revision,
274
- torch_dtype=torch.bfloat16,
275
- device_map=device,
276
- low_cpu_mem_usage=True,
277
- load_in_4bit=load_in_4bit,
278
- load_in_8bit=load_in_8bit,
279
- trust_remote_code=trust_remote_code,
280
- )
281
-
282
- def _do_it(tokenized):
283
- return get_hidden_states(
284
- model, tokenized=tokenized, average=mode == "hidden_avg"
285
- )
286
-
287
-
288
- gate_vecs = []
289
- print(experts)
290
- for expert in tqdm.tqdm(experts, desc="expert prompts"):
291
- print(_do_it)
292
- hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
293
- if expert.negative_prompts:
294
- hidden_states -= _do_it(
295
- tokenize_prompts(expert.negative_prompts, tokenizer)
296
- )
297
-
298
- hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
299
- gate_vecs.append(hidden_states)
300
- gate_vecs = torch.stack(gate_vecs, dim=0) # (num_expert, num_layer, hidden_size)
301
- return gate_vecs.permute(1, 0, 2)
302
-
303
-
304
- def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
305
- degen_indices = []
306
- num_layers, _num_experts, _hidden_size = gate_vecs.shape
307
- for idx in range(num_layers):
308
- c = torch.linalg.cond(gate_vecs[idx, :, :].float())
309
- if c > threshold:
310
- degen_indices.append(idx)
311
-
312
- if degen_indices:
313
- if len(degen_indices) == 1:
314
- layer_str = f"layer {degen_indices[0]}"
315
- verb = "has"
316
- elif len(degen_indices) == 2:
317
- layer_str = f"layers {' and '.join(map(str, degen_indices))}"
318
- verb = "have"
319
- elif len(degen_indices) >= num_layers:
320
- layer_str = "ALL layers"
321
- verb = "have"
322
- else:
323
- layer_str = (
324
- "layers "
325
- + ", ".join(map(str, degen_indices[:-1]))
326
- + ", and "
327
- + str(degen_indices[-1])
328
- )
329
- verb = "have"
330
-
331
- logging.warning(
332
- f"{layer_str} {verb} degenerate routing parameters "
333
- "- your prompts may be too similar."
334
- )
335
- logging.warning("One or more experts will be underutilized in your model.")
336
-
337
-
338
- def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
339
- if len(config.experts) < 2:
340
- logging.error("Must include at least two experts.")
341
- return True
342
-
343
- if config.gate_mode == "random":
344
- return False # eh we're good
345
-
346
- def prompt_tup(e: Expert):
347
- return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))
348
-
349
- # let's just nip this trend in the bud
350
- p_first = prompt_tup(config.experts[0])
351
- if all(prompt_tup(e) == p_first for e in config.experts[1:]):
352
- logging.error(
353
- "Your positive and negative prompts are identical for all experts. This will not produce a functioning MoE."
354
- )
355
- logging.error(
356
- "For each expert, `positive_prompts` must contain one or more example prompt reflecting what should be routed to that expert."
357
- )
358
- return True
359
-
360
- if not allow_all_same:
361
- if all(
362
- e.source_model == config.experts[0].source_model for e in config.experts[1:]
363
- ):
364
- logging.error(
365
- "All of your expert models are the same. This will produce "
366
- "a model that uses more resources but gives the exact same output. "
367
- "If you plan to train the model after merging, proceed with the "
368
- "--i-understand-this-is-not-useful-without-training flag."
369
- )
370
- return True
371
-
372
-
373
- def build(
374
- config: MistralMOEConfig,
375
- out_path: str,
376
- merge_options: MergeOptions,
377
- load_in_4bit: bool = False,
378
- load_in_8bit: bool = False,
379
- device: str = "auto",
380
- allow_all_same: bool = False,
381
- ):
382
- if is_bad_config(config, allow_all_same=allow_all_same):
383
- sys.exit(1)
384
-
385
- if config.experts_per_token < 1:
386
- logging.error("Experts per token must be >= 1")
387
- sys.exit(1)
388
- if config.experts_per_token > len(config.experts):
389
- logging.error("Experts per token must be <= number of experts")
390
- sys.exit(1)
391
-
392
- base_model = ModelReference.parse(config.base_model)
393
- base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
394
- if not isinstance(base_cfg, MistralConfig):
395
- base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
396
- base_cfg_mistral.sliding_window = None
397
- base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
398
- base_cfg = base_cfg_mistral
399
-
400
- out_cfg = MixtralConfig(**base_cfg.to_dict())
401
- out_cfg.architectures = ["PhiForCausalLM"]
402
- out_cfg.num_local_experts = len(config.experts)
403
- out_cfg.num_experts_per_tok = config.experts_per_token
404
- out_cfg.sliding_window = None
405
- if config.dtype:
406
- out_cfg.torch_dtype = config.dtype
407
- out_cfg.save_pretrained(out_path)
408
-
409
- if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
410
- logging.warning(
411
- f"Your model has {out_cfg.num_local_experts} experts, which is "
412
- "not a power of two. The model will not be usable in llama.cpp."
413
- )
414
-
415
- loaders: Dict[ModelReference, LazyTensorLoader] = {}
416
- for model in tqdm.tqdm(
417
- [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
418
- ):
419
- loaders[model] = LazyTensorLoader(
420
- model.tensor_index(cache_dir=merge_options.transformers_cache),
421
- lazy_unpickle=merge_options.lazy_unpickle,
422
- )
423
-
424
- base_loader = loaders.get(base_model)
425
- writer = TensorWriter(
426
- out_path=out_path,
427
- max_shard_size=merge_options.out_shard_size,
428
- safe_serialization=merge_options.safe_serialization,
429
- )
430
-
431
- if config.dtype:
432
- out_dtype = dtype_from_name(config.dtype)
433
- elif base_cfg.torch_dtype:
434
- out_dtype = base_cfg.torch_dtype
435
- if isinstance(out_dtype, str):
436
- out_dtype = dtype_from_name(out_dtype)
437
- else:
438
- out_dtype = None
439
-
440
- logging.info("Copying parameters...")
441
- MISTRAL_INFO = mergekit.architecture.PHI2_INFO
442
- for tensor_name in MISTRAL_INFO.pre_weight_names + MISTRAL_INFO.post_weight_names:
443
- tensor = base_loader.get_tensor(tensor_name)
444
- if not out_dtype:
445
- # All else has failed, take the first dtype we see
446
- out_dtype = tensor.dtype
447
- writer.save_tensor(
448
- tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
449
- )
450
- set_of_seen_tensors = set()
451
-
452
- for name_format in tqdm.tqdm(MISTRAL_INFO.layer_weight_formats()):
453
- for layer_idx in range(base_cfg.num_hidden_layers):
454
- tensor_name = name_format.format(idx=layer_idx)
455
- if ".mlp.fc" in name_format:
456
- for moe_index, expert in enumerate(config.experts):
457
- if tensor_name in set_of_seen_tensors:
458
- expert_name = tensor_name.replace(
459
- ".mlp.fc", f".moe.mlp.1.fc"
460
- )
461
- else:
462
- expert_name = tensor_name.replace(
463
- ".mlp.fc", f".moe.mlp.0.fc"
464
- )
465
- set_of_seen_tensors.add(tensor_name)
466
-
467
- expert_loader = loaders.get(expert.model_ref)
468
- tensor = expert_loader.get_tensor(tensor_name)
469
- if expert.noise_scale:
470
- tensor += torch.randn_like(tensor) * expert.noise_scale
471
- writer.save_tensor(
472
- expert_name, tensor.to(dtype=out_dtype), clone=True
473
- )
474
- print(expert_name, tensor_name)
475
- continue
476
- writer.save_tensor(
477
- tensor_name, base_loader.get_tensor(tensor_name).to(dtype=out_dtype)
478
- )
479
-
480
- tokenizer = transformers.AutoTokenizer.from_pretrained(
481
- base_model.model.path, revision=base_model.model.revision
482
- )
483
- tokenizer.padding_side = "left"
484
- tokenizer.pad_token_id = tokenizer.bos_token_id
485
-
486
- logging.info("Getting gate parameters...")
487
- gate_vecs = get_gate_params(
488
- base_model,
489
- tokenizer,
490
- config.experts,
491
- mode=config.gate_mode,
492
- load_in_4bit=load_in_4bit,
493
- load_in_8bit=load_in_8bit,
494
- lazy_unpickle=merge_options.lazy_unpickle,
495
- trust_remote_code=merge_options.trust_remote_code,
496
- device=device,
497
- )
498
- # gate_vecs: (num_layers, num_experts, hidden_size)
499
-
500
- warn_degenerate_gates(gate_vecs)
501
-
502
- for layer_idx in range(base_cfg.num_hidden_layers):
503
- writer.save_tensor(
504
- f"transformer.h.{layer_idx}.moe.gate.weight",
505
- gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
506
- )
507
- writer.finalize()
508
-
509
- if merge_options.copy_tokenizer:
510
- logging.info("Saving tokenizer...")
511
- tokenizer.save_pretrained(out_path, safe_serialization=True)
512
-
513
- logging.info("Done.")
514
-
515
-
516
- @click.command("mergekit-moe")
517
- @click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
518
- @click.argument("out_path", type=click.Path())
519
- @click.option(
520
- "--load-in-4bit",
521
- is_flag=True,
522
- type=bool,
523
- default=False,
524
- help="Load model in 4bit for computing hidden states",
525
- )
526
- @click.option(
527
- "--load-in-8bit",
528
- is_flag=True,
529
- type=bool,
530
- default=False,
531
- help="Load model in 8bit for computing hidden states",
532
- )
533
- @click.option(
534
- "--device",
535
- type=str,
536
- default="auto",
537
- help="Device to use to compute embeddings",
538
- show_default=True,
539
- )
540
- @click.option(
541
- "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
542
- )
543
- @click.option(
544
- "--i-understand-this-is-not-useful-without-training",
545
- type=bool,
546
- default=False,
547
- is_flag=True,
548
- help="Really make the questionable model you want.",
549
- )
550
- @add_merge_options
551
- def main(
552
- config_path: str,
553
- out_path: str,
554
- load_in_4bit: bool,
555
- load_in_8bit: bool,
556
- device: str,
557
- merge_options: MergeOptions,
558
- verbose: bool,
559
- i_understand_this_is_not_useful_without_training: bool,
560
- ):
561
- logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)
562
-
563
- if merge_options.cuda:
564
- logging.warning(
565
- '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
566
- )
567
-
568
- with open(config_path, "r", encoding="utf-8") as file:
569
- config_source = file.read()
570
-
571
- config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
572
- build(
573
- config,
574
- out_path=out_path,
575
- merge_options=merge_options,
576
- load_in_4bit=load_in_4bit,
577
- load_in_8bit=load_in_8bit,
578
- device=device,
579
- allow_all_same=i_understand_this_is_not_useful_without_training,
580
- )
581
-
582
- if merge_options.write_model_card:
583
- # TODO: generate a README.md as well
584
- with open(
585
- os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
586
- ) as fp:
587
- fp.write(config_source)
588
-
589
-
590
- if __name__ == "__main__":
591
- main()
592
  ```
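For reference, the modified script still consumes the YAML schema defined by the `Expert` and `MistralMOEConfig` classes above. A minimal sketch of such a config follows; the model names are placeholders, not tested values.

```
# config.yml for the modified mixtral_moe.py (schema: MistralMOEConfig / Expert above)
base_model: microsoft/phi-2        # example base model
gate_mode: hidden                  # one of: hidden, cheap_embed, random
dtype: bfloat16
experts_per_token: 2
experts:
  - source_model: your-org/phi-2-code    # placeholder fine-tune
    positive_prompts:
      - "Write a Python function that"
      - "Fix the bug in this code:"
  - source_model: your-org/phi-2-chat    # placeholder fine-tune
    positive_prompts:
      - "Explain the following concept in simple terms:"
      - "Summarize this article:"
```

Saved as `config.yml`, this can be passed to the script's click CLI, e.g. `python /content/mergekit/mergekit/scripts/mixtral_moe.py config.yml ./phi-moe-out --device cuda` (assuming mergekit is installed in the environment). If every expert points at the same source model, add `--i-understand-this-is-not-useful-without-training`.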
593
 
594
  2. Modify architecture.py `/content/mergekit/mergekit/architecture.py`
595
 (you can take this from the commit linked in the description)
596
 
597
- ***architecture.py***
+ [***architecture.py***](https://github.com/paulilioaica/Phi-MOE/blob/main/architecture.py)
598
 
599
- ```
600
- # Copyright (C) 2024 Charles O. Goddard
601
- #
602
- # This software is free software: you can redistribute it and/or
603
- # modify it under the terms of the GNU Lesser General Public License as
604
- # published by the Free Software Foundation, either version 3 of the
605
- # License, or (at your option) any later version.
606
- #
607
- # This software is distributed in the hope that it will be useful, but
608
- # WITHOUT ANY WARRANTY; without even the implied warranty of
609
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
610
- # Lesser General Public License for more details.
611
- #
612
- # You should have received a copy of the GNU Lesser General Public License
613
- # along with this program. If not, see http://www.gnu.org/licenses/.
614
-
615
- from abc import ABC, abstractmethod
616
- from typing import List, Optional
617
-
618
- from pydantic import BaseModel
619
- from transformers import PretrainedConfig
620
-
621
-
622
- class ArchitectureInfo(ABC):
623
- @abstractmethod
624
- def pre_weights(self) -> List[str]:
625
- """Return a list of all weights preceding the first layer."""
626
- ...
627
-
628
- @abstractmethod
629
- def post_weights(self) -> List[str]:
630
- """Return a list of all weights following the final layer."""
631
- ...
632
-
633
- @abstractmethod
634
- def layer_weight_formats(self) -> List[str]:
635
- """Return a list of format strings all weights associated with a layer."""
636
- ...
637
-
638
- @abstractmethod
639
- def embed_weights(self) -> List[str]:
640
- ...
641
-
642
- def num_layers(self, config: PretrainedConfig) -> int:
643
- return config.num_hidden_layers
644
-
645
- def num_layers_config_key(self) -> str:
646
- """Key in config that represents number of layers"""
647
- return "num_hidden_layers"
648
-
649
-
650
- class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True):
651
- name: str
652
-
653
- pre_weight_names: List[str] # weights applied before first layer
654
- post_weight_names: List[str] # weights applied after last layer
655
- embed_weight_names: List[str] # weights for embed/lm_head
656
- layer_prefix_format: str
657
- layer_weight_suffixes: List[str]
658
- num_layers_key: Optional[str] = None
659
-
660
- def pre_weights(self) -> List[str]:
661
- return self.pre_weight_names
662
-
663
- def post_weights(self) -> List[str]:
664
- return self.post_weight_names
665
-
666
- def embed_weights(self) -> List[str]:
667
- return self.embed_weight_names
668
-
669
- def layer_weight_formats(self) -> List[str]:
670
- res = []
671
- for suffix in self.layer_weight_suffixes:
672
- res.append(self.layer_prefix_format + "." + suffix)
673
- return res
674
-
675
- def num_layers_config_key(self) -> str:
676
- if self.num_layers_key:
677
- return self.num_layers_key
678
- return super().num_layers_config_key()
679
-
680
- def num_layers(self, config: PretrainedConfig) -> int:
681
- return getattr(config, self.num_layers_config_key())
682
-
683
- def all_weights(self, config: PretrainedConfig) -> List[str]:
684
- num_layers = self.num_layers(config)
685
- tensor_names = list(self.pre_weights())
686
- for layer_idx in range(num_layers):
687
- for f in self.layer_weight_formats():
688
- tensor_names.append(f.format(idx=layer_idx))
689
- tensor_names.extend(self.post_weights())
690
- return tensor_names
691
-
692
-
693
- LLAMA_INFO = StaticTensorNames(
694
- name="LlamaForCausalLM",
695
- pre_weight_names=["model.embed_tokens.weight"],
696
- post_weight_names=["model.norm.weight", "lm_head.weight"],
697
- embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
698
- layer_prefix_format="model.layers.{idx}",
699
- layer_weight_suffixes=[
700
- "input_layernorm.weight",
701
- "mlp.up_proj.weight",
702
- "mlp.down_proj.weight",
703
- "mlp.gate_proj.weight",
704
- "post_attention_layernorm.weight",
705
- "self_attn.q_proj.weight",
706
- "self_attn.k_proj.weight",
707
- "self_attn.v_proj.weight",
708
- "self_attn.o_proj.weight",
709
- ],
710
- )
711
-
712
- MISTRAL_INFO = StaticTensorNames(
713
- name="MistralForCausalLM",
714
- # lol
715
- **LLAMA_INFO.model_dump(exclude=["name"]),
716
- )
717
-
718
-
719
- STABLELM_INFO = StaticTensorNames(
720
- name="StableLMEpochForCausalLM",
721
- post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
722
- layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
723
- + [
724
- "input_layernorm.bias",
725
- "post_attention_layernorm.bias",
726
- ],
727
- **LLAMA_INFO.model_dump(
728
- exclude=["name", "layer_weight_suffixes", "post_weight_names"]
729
- ),
730
- )
731
-
732
- GPT_NEOX_INFO = StaticTensorNames(
733
- name="GPTNeoXForCausalLM",
734
- pre_weight_names=["gpt_neox.embed_in.weight"],
735
- post_weight_names=[
736
- "gpt_neox.final_layer_norm.bias",
737
- "gpt_neox.final_layer_norm.weight",
738
- "embed_out.weight",
739
- ],
740
- embed_weight_names=["gpt_neox.embed_in.weight", "embed_out.weight"],
741
- layer_prefix_format="gpt_neox.layers.{idx}",
742
- layer_weight_suffixes=sum(
743
- (
744
- [f"{prefix}.weight", f"{prefix}.bias"]
745
- for prefix in [
746
- "attention.dense",
747
- "attention.query_key_value",
748
- "input_layernorm",
749
- "mlp.dense_4h_to_h",
750
- "mlp.dense_h_to_4h",
751
- "post_attention_layernorm",
752
- ]
753
- ),
754
- start=[],
755
- )
756
- + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
757
- )
758
-
759
- GPT2_INFO = StaticTensorNames(
760
- name="GPT2LMHeadModel",
761
- pre_weight_names=["wte.weight", "wpe.weight"],
762
- post_weight_names=["ln_f.weight", "ln_f.bias"],
763
- embed_weight_names=["wte.weight"],
764
- layer_prefix_format="h.{idx}",
765
- layer_weight_suffixes=[
766
- "attn.c_attn.weight",
767
- "attn.c_attn.bias",
768
- "attn.c_proj.weight",
769
- "attn.c_proj.bias",
770
- "ln_1.weight",
771
- "ln_1.bias",
772
- "ln_2.weight",
773
- "ln_2.bias",
774
- "mlp.c_proj.weight",
775
- "mlp.c_proj.bias",
776
- "mlp.c_fc.weight",
777
- "mlp.c_fc.bias",
778
- "mlp.c_proj.weight",
779
- "mlp.c_proj.bias",
780
- ],
781
- num_layers_key="n_layer",
782
- )
783
-
784
- JAIS_INFO = StaticTensorNames(
785
- name="JAISLMHeadModel",
786
- pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
787
- post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
788
- embed_weight_names=["transformer.wte.weight"],
789
- layer_prefix_format="transformer.h.{idx}",
790
- layer_weight_suffixes=[
791
- "attn.c_attn.weight",
792
- "attn.c_attn.bias",
793
- "attn.c_proj.weight",
794
- "attn.c_proj.bias",
795
- "ln_1.weight",
796
- "ln_1.bias",
797
- "ln_2.weight",
798
- "ln_2.bias",
799
- "mlp.c_fc.weight",
800
- "mlp.c_fc.bias",
801
- "mlp.c_fc2.weight",
802
- "mlp.c_fc2.bias",
803
- "mlp.c_proj.weight",
804
- "mlp.c_proj.bias",
805
- ],
806
- num_layers_key="n_layer",
807
- )
808
-
809
- GPT2_SEQCLASS_INFO = StaticTensorNames(
810
- name="GPT2ForSequenceClassification",
811
- pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
812
- post_weight_names=[
813
- "transformer.ln_f.weight",
814
- "transformer.ln_f.bias",
815
- "score.weight",
816
- ],
817
- layer_prefix_format="transformer.h.{idx}",
818
- embed_weight_names=GPT2_INFO.embed_weight_names,
819
- layer_weight_suffixes=GPT2_INFO.layer_weight_suffixes,
820
- num_layers_key=GPT2_INFO.num_layers_key,
821
- )
822
-
823
-
824
- QWEN_INFO = StaticTensorNames(
825
- name="QWenLMHeadModel",
826
- pre_weight_names=["transformer.wte.weight"],
827
- post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
828
- embed_weight_names=["transformer.wte.weight", "lm_head.weight"],
829
- layer_prefix_format="transformer.h.{idx}",
830
- layer_weight_suffixes=[
831
- "attn.c_attn.bias",
832
- "attn.c_attn.weight",
833
- "attn.c_proj.weight",
834
- "ln_1.weight",
835
- "ln_2.weight",
836
- "mlp.c_proj.weight",
837
- "mlp.w1.weight",
838
- "mlp.w2.weight",
839
- ],
840
- )
841
-
842
- CHATGLM_INFO = StaticTensorNames(
843
- name="ChatGLMModel",
844
- pre_weight_names=[
845
- "transformer.embedding.word_embeddings.weight",
846
- "transformer.rotary_pos_emb.inv_freq",
847
- ],
848
- post_weight_names=[
849
- "transformer.encoder.final_layernorm.weight",
850
- "transformer.output_layer.weight",
851
- ],
852
- embed_weight_names=[
853
- "transformer.embedding.word_embeddings.weight",
854
- "transformer.output_layer.weight",
855
- ],
856
- layer_prefix_format="transformer.encoder.layers.{idx}",
857
- layer_weight_suffixes=[
858
- "input_layernorm.weight",
859
- "mlp.dense_4h_to_h.weight",
860
- "mlp.dense_h_to_4h.weight",
861
- "post_attention_layernorm.weight",
862
- "self_attention.dense.weight",
863
- "self_attention.query_key_value.bias",
864
- "self_attention.query_key_value.weight",
865
- ],
866
- )
867
-
868
- FALCON_INFO = StaticTensorNames(
869
- name="FalconForCausalLM",
870
- pre_weight_names=["transformer.word_embeddings.weight"],
871
- post_weight_names=[
872
- "transformer.ln_f.weight",
873
- "transformer.ln_f.bias",
874
- "lm_head.weight",
875
- ],
876
- embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
877
- layer_prefix_format="transformer.h.{idx}",
878
- layer_weight_suffixes=[
879
- "ln_attn.bias",
880
- "ln_attn.weight",
881
- "ln_mlp.bias",
882
- "ln_mlp.weight",
883
- "mlp.dense_4h_to_h.weight",
884
- "mlp.dense_h_to_4h.weight",
885
- "self_attention.dense.weight",
886
- "self_attention.query_key_value.weight",
887
- ],
888
- )
889
-
890
-
891
- class PhiTensorNames(ArchitectureInfo):
892
- architecture_name: str = "MixFormerSequentialForCausalLM"
893
-
894
- def __init__(self, config: PretrainedConfig):
895
- self.config = config
896
-
897
- def __eq__(self, rhs: "PhiTensorNames"):
898
- if not isinstance(rhs, PhiTensorNames):
899
- return False
900
- return self.num_layers() == rhs.num_layers()
901
-
902
- def pre_weights(self) -> List[str]:
903
- return ["layers.0.wte.weight"]
904
-
905
- def post_weights(self) -> List[str]:
906
- fake_layer_idx = self.config.n_layer + 1
907
- return [
908
- f"layers.{fake_layer_idx}.{suffix}"
909
- for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
910
- ]
911
-
912
- def embed_weights(self) -> List[str]:
913
- fake_layer_idx = self.config.n_layer + 1
914
- return [
915
- "layers.0.wte.weight",
916
- f"layers.{fake_layer_idx}.linear.weight",
917
- f"layers.{fake_layer_idx}.linear.bias",
918
- ]
919
-
920
- def layer_weight_formats(self) -> List[str]:
921
- return [
922
- ("layers.{idx}." + suffix)
923
- for suffix in [
924
- "ln.bias",
925
- "ln.weight",
926
- "mixer.Wqkv.bias",
927
- "mixer.Wqkv.weight",
928
- "mixer.out_proj.bias",
929
- "mixer.out_proj.weight",
930
- "mixer.rotary_emb.inv_freq",
931
- "mlp.fc1.bias",
932
- "mlp.fc1.weight",
933
- "mlp.fc2.bias",
934
- "mlp.fc2.weight",
935
- ]
936
- ]
937
-
938
- def num_layers(self, config: PretrainedConfig) -> int:
939
- return config.n_layer
940
-
941
- def num_layers_config_key(self) -> str:
942
- return "n_layer"
943
-
944
-
945
- PHI2_INFO = StaticTensorNames(
946
- name="PhiForCausalLM",
947
- pre_weight_names=["transformer.embd.wte.weight"],
948
- post_weight_names=[
949
- "lm_head.linear.bias",
950
- "lm_head.linear.weight",
951
- "lm_head.ln.bias",
952
- "lm_head.ln.weight",
953
- ],
954
- embed_weight_names=["lm_head.linear.weight", "transformer.embd.wte.weight"],
955
- layer_prefix_format="transformer.h.{idx}",
956
- layer_weight_suffixes=[
957
- "ln.bias",
958
- "ln.weight",
959
- "mixer.out_proj.bias",
960
- "mixer.out_proj.weight",
961
- "mixer.Wqkv.bias",
962
- "mixer.Wqkv.weight",
963
- "mlp.fc1.bias",
964
- "mlp.fc1.weight",
965
- "mlp.fc2.bias",
966
- "mlp.fc2.weight",
967
- ],
968
- num_layers_key="n_layer",
969
- )
970
-
971
-
972
- PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
973
- name="PhiForCausalLM",
974
- pre_weight_names=["model.embed_tokens.weight"],
975
- post_weight_names=[
976
- "lm_head.bias",
977
- "lm_head.weight",
978
- "model.final_layernorm.bias",
979
- "model.final_layernorm.weight",
980
- ],
981
- embed_weight_names=["lm_head.weight", "model.embed_tokens.weight"],
982
- layer_prefix_format="model.layers.{idx}",
983
- layer_weight_suffixes=[
984
- "input_layernorm.bias",
985
- "input_layernorm.weight",
986
- "self_attn.dense.bias",
987
- "self_attn.dense.weight",
988
- "self_attn.q_proj.bias",
989
- "self_attn.q_proj.weight",
990
- "self_attn.k_proj.bias",
991
- "self_attn.k_proj.weight",
992
- "self_attn.v_proj.bias",
993
- "self_attn.v_proj.weight",
994
- "mlp.fc1.bias",
995
- "mlp.fc1.weight",
996
- "mlp.fc2.bias",
997
- "mlp.fc2.weight",
998
- ],
999
- )
1000
-
1001
-
1002
- BAICHUAN_INFO = StaticTensorNames(
1003
- name="BaichuanForCausalLM",
1004
- pre_weight_names=["model.embed_tokens.weight"],
1005
- post_weight_names=["model.norm.weight", "lm_head.weight"],
1006
- embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
1007
- layer_prefix_format="model.layers.{idx}",
1008
- layer_weight_suffixes=[
1009
- "input_layernorm.weight",
1010
- "self_attn.W_pack.weight",
1011
- "self_attn.o_proj.weight",
1012
- "post_attention_layernorm.weight",
1013
- "mlp.gate_proj.weight",
1014
- "mlp.down_proj.weight",
1015
- "mlp.up_proj.weight",
1016
- ],
1017
- )
1018
-
1019
-
1020
- def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
1021
- if len(config.architectures) != 1:
1022
- raise RuntimeError("More than one architecture in config?")
1023
-
1024
- arch_name = config.architectures[0]
1025
- if arch_name == PhiTensorNames.architecture_name:
1026
- return PhiTensorNames(config)
1027
-
1028
- if arch_name == PHI2_INFO.name:
1029
- if config.model_type == "phi-msft":
1030
- return PHI2_INFO
1031
- elif config.model_type == "phi":
1032
- return PHI2_INFO_AGAIN_BUT_DIFFERENT
1033
-
1034
- supported = [
1035
- LLAMA_INFO,
1036
- MISTRAL_INFO,
1037
- GPT_NEOX_INFO,
1038
- QWEN_INFO,
1039
- GPT2_INFO,
1040
- GPT2_SEQCLASS_INFO,
1041
- CHATGLM_INFO,
1042
- STABLELM_INFO,
1043
- JAIS_INFO,
1044
- BAICHUAN_INFO,
1045
- FALCON_INFO,
1046
- ]
1047
- for arch in supported:
1048
- if arch.name == arch_name:
1049
- return arch
1050
-
1051
- raise RuntimeError(f"Unsupported architecture {arch_name}")
1052
- ```
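As a quick sanity check that the modified `architecture.py` resolves Phi checkpoints correctly, the sketch below (assuming the edited mergekit checkout is installed, with `microsoft/phi-2` as an example base model) looks up the architecture info and prints the tensor names the merge loop will expect.

```
# Sanity check for the modified architecture.py: a Phi config should resolve
# to one of the PhiForCausalLM tensor-name tables defined above.
from transformers import AutoConfig

from mergekit.architecture import get_architecture_info

cfg = AutoConfig.from_pretrained("microsoft/phi-2", trust_remote_code=True)
arch = get_architecture_info(cfg)
print(arch.name)  # expected: PhiForCausalLM

# First few tensor names the merge script will look for in each model.
for tensor_name in arch.all_weights(cfg)[:12]:
    print(tensor_name)
```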
1053
 
1054
 3. Replace `config.json` with the one from **this repo**
1055
 4. Add `modeling_phi.py` and `configuration_phi.py` from **this repo** to your repo
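Once `config.json`, `modeling_phi.py`, and `configuration_phi.py` are in place, the merged model can be loaded like any other remote-code checkpoint. A minimal sketch, with `your-username/phi-moe` as a placeholder repo id (a local output path such as `./phi-moe-out` works the same way):

```
# Load the merged Phi MoE; trust_remote_code is required because the repo
# ships its own modeling_phi.py / configuration_phi.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-username/phi-moe"  # placeholder repo id or local path
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```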
 