release version
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- README.md +160 -0
- config.json +30 -0
- configuration_japanese_stablelm_alpha.py +120 -0
- generation_config.json +6 -0
- japanese-stablelm-parrot.jpg +3 -0
- modeling_japanese_stablelm_alpha.py +682 -0
- pytorch_model-00001-of-00003.bin +3 -0
- pytorch_model-00002-of-00003.bin +3 -0
- pytorch_model-00003-of-00003.bin +3 -0
- pytorch_model.bin.index.json +267 -0
- requirements.txt +2 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
README.md
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- ja
|
4 |
+
tags:
|
5 |
+
- japanese-stablelm
|
6 |
+
- causal-lm
|
7 |
+
pipeline_tag: text-generation
|
8 |
+
datasets:
|
9 |
+
- wikipedia
|
10 |
+
- mc4
|
11 |
+
- cc100
|
12 |
+
- oscar-corpus/OSCAR-2301
|
13 |
+
- oscar-corpus/OSCAR-2201
|
14 |
+
- togethercomputer/RedPajama-Data-1T
|
15 |
+
license:
|
16 |
+
- apache-2.0
|
17 |
+
---
|
18 |
+
|
19 |
+
# Japanese-StableLM-Base-Alpha-7B
|
20 |
+
|
21 |
+
![japanese-stablelm-icon](./japanese-stablelm-parrot.jpg)
|
22 |
+
|
23 |
+
> "A parrot able to speak Japanese, ukiyoe, edo period" — [Stable Diffusion XL](https://clipdrop.co/stable-diffusion)
|
24 |
+
|
25 |
+
## Model Description
|
26 |
+
|
27 |
+
`japanese-stablelm-base-alpha-7b` is a 7B-parameter decoder-only language model pre-trained on a diverse collection of Japanese and English datasets which focus on maximizing Japanese language modeling performance and Japanese downstream task performance.
|
28 |
+
|
29 |
+
For an instruction-following model, check [Japanese-StableLM-Instruct-Alpha-7B](https://huggingface.co/stabilityai/japanese-stablelm-instruct-alpha-7b) and get access by accepting the terms and conditions.
|
30 |
+
|
31 |
+
## Usage
|
32 |
+
|
33 |
+
First install additional dependencies in [requirements.txt](./requirements.txt):
|
34 |
+
|
35 |
+
```sh
|
36 |
+
pip install sentencepiece einops
|
37 |
+
```
|
38 |
+
|
39 |
+
Then start generating text with `japanese-stablelm-base-alpha-7b` by using the following code snippet:
|
40 |
+
|
41 |
+
```python
|
42 |
+
import torch
|
43 |
+
from transformers import LlamaTokenizer, AutoModelForCausalLM
|
44 |
+
|
45 |
+
tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
|
46 |
+
|
47 |
+
model = AutoModelForCausalLM.from_pretrained(
|
48 |
+
"stabilityai/japanese-stablelm-base-alpha-7b",
|
49 |
+
trust_remote_code=True,
|
50 |
+
)
|
51 |
+
model.half()
|
52 |
+
|
53 |
+
if torch.cuda.is_available():
|
54 |
+
model = model.to("cuda")
|
55 |
+
|
56 |
+
prompt = """
|
57 |
+
AI で科学研究を加速するには、
|
58 |
+
""".strip()
|
59 |
+
|
60 |
+
input_ids = tokenizer.encode(
|
61 |
+
prompt,
|
62 |
+
add_special_tokens=False,
|
63 |
+
return_tensors="pt"
|
64 |
+
)
|
65 |
+
|
66 |
+
# this is for reproducibility.
|
67 |
+
# free free to change to get different result
|
68 |
+
seed = 23
|
69 |
+
torch.manual_seed(seed)
|
70 |
+
|
71 |
+
tokens = model.generate(
|
72 |
+
input_ids.to(device=model.device),
|
73 |
+
max_new_tokens=128,
|
74 |
+
temperature=1,
|
75 |
+
top_p=0.95,
|
76 |
+
do_sample=True,
|
77 |
+
)
|
78 |
+
|
79 |
+
out = tokenizer.decode(tokens[0], skip_special_tokens=False)
|
80 |
+
print(out)
|
81 |
+
"""
|
82 |
+
AI で科学研究を加速するには、データ駆動型文化が必要であることも明らかになってきています。研究のあらゆる側面で、データがより重要になっているのです。
|
83 |
+
20 世紀の科学は、研究者が直接研究を行うことで、研究データを活用してきました。その後、多くの科学分野ではデータは手動で分析されるようになったものの、これらの方法には多大なコストと労力がかかることが分かりました。 そこで、多くの研究者や研究者グループは、より効率的な手法を開発し、研究の規模を拡大してきました。21 世紀になると、研究者が手動で実施する必要のある研究は、その大部分を研究者が自動化できるようになりました。
|
84 |
+
"""
|
85 |
+
```
|
86 |
+
|
87 |
+
We suggest playing with different generation config (`top_p`, `repetition_penalty` etc) to find the best setup for your tasks. For example, use higher temperature for roleplay task, lower temperature for reasoning.
|
88 |
+
|
89 |
+
## Model Details
|
90 |
+
|
91 |
+
* **Model type**: `japanese-stablelm-base-alpha-7b` model is an auto-regressive language model based on the NeoX transformer architecture.
|
92 |
+
* **Language(s)**: Japanese
|
93 |
+
* **Library**: [GPT-NeoX](https://github.com/EleutherAI/gpt-neox)
|
94 |
+
* **License**: This model is licensed under [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0).
|
95 |
+
|
96 |
+
|
97 |
+
## Training
|
98 |
+
|
99 |
+
| Parameters | Hidden Size | Layers | Heads | Sequence Length |
|
100 |
+
|------------|-------------|--------|-------|-----------------|
|
101 |
+
| 7B | 4096 | 32 | 32 | 2048 |
|
102 |
+
|
103 |
+
### Training Dataset
|
104 |
+
|
105 |
+
`japanese-stablelm-base-alpha-7b` is pre-trained on around 750B tokens from a mixture of the following corpora:
|
106 |
+
|
107 |
+
- [Japanese/English Wikipedia](https://dumps.wikimedia.org/other/cirrussearch)
|
108 |
+
- [Japanese mc4](https://huggingface.co/datasets/mc4)
|
109 |
+
- [Japanese CC-100](http://data.statmt.org/cc-100/ja.txt.xz)
|
110 |
+
- [Japanese OSCAR](https://oscar-project.github.io/documentation/)
|
111 |
+
- [RedPajama](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T)
|
112 |
+
|
113 |
+
## Use and Limitations
|
114 |
+
|
115 |
+
### Intended Use
|
116 |
+
|
117 |
+
The model is intended to be used by all individuals as foundational models for application-specific fine-tuning without strict limitations on commercial use.
|
118 |
+
|
119 |
+
### Limitations and bias
|
120 |
+
|
121 |
+
The pre-training dataset may have contained offensive or inappropriate content even after applying data cleansing filters which can be reflected in the model generated text. We recommend users exercise reasonable caution when using these models in production systems. Do not use the model for any applications that may cause harm or distress to individuals or groups.
|
122 |
+
|
123 |
+
## Authors
|
124 |
+
- [Meng Lee](https://huggingface.co/leemeng)
|
125 |
+
- [Fujiki Nakamura](https://huggingface.co/fujiki)
|
126 |
+
- [Makoto Shing](https://huggingface.co/mkshing)
|
127 |
+
- [Paul McCann](https://huggingface.co/polm-stability)
|
128 |
+
- [Takuya Akiba](https://huggingface.co/iwiwi)
|
129 |
+
- [Naoki Orii](https://huggingface.co/mrorii)
|
130 |
+
|
131 |
+
## Acknowledgements
|
132 |
+
|
133 |
+
We are utilizing the v1 version of the [novelai-tokenizer](https://github.com/NovelAI/novelai-tokenizer), introduced by [NovelAI](https://novelai.net/), because it processes both Japanese and English text effectively and efficiently. We extend our gratitude to NovelAI for allowing us to use their remarkable work. For more details about the tokenizer, please refer to their [blog post](https://blog.novelai.net/novelais-new-llm-tokenizer-5bc140e17642).
|
134 |
+
|
135 |
+
We are grateful for the contributions of the EleutherAI Polyglot-JA team in helping us to collect a large amount of pre-training data in Japanese. Polyglot-JA members includes Kevin (Project Lead), Fujiki (originally started this project when he commited to the Polyglot team), Yunho, Minji and Su-Kyeong Jang.
|
136 |
+
|
137 |
+
We are also appreciative of [AI Novelist/Sta (Bit192, Inc.)](https://ai-novel.com/index.php) and the numerous contributors from [Stable Community Japan](https://discord.gg/VPrcE475HB) for assisting us in gathering a large amount of high-quality Japanese textual data for model training.
|
138 |
+
|
139 |
+
## Citations
|
140 |
+
|
141 |
+
```bibtext
|
142 |
+
@software{gpt-neox-library,
|
143 |
+
title = {{GPT-NeoX: Large Scale Autoregressive Language Modeling in PyTorch}},
|
144 |
+
author = {Andonian, Alex and Anthony, Quentin and Biderman, Stella and Black, Sid and Gali, Preetham and Gao, Leo and Hallahan, Eric and Levy-Kramer, Josh and Leahy, Connor and Nestler, Lucas and Parker, Kip and Pieler, Michael and Purohit, Shivanshu and Songz, Tri and Phil, Wang and Weinbach, Samuel},
|
145 |
+
url = {https://www.github.com/eleutherai/gpt-neox},
|
146 |
+
doi = {10.5281/zenodo.5879544},
|
147 |
+
month = {8},
|
148 |
+
year = {2021},
|
149 |
+
version = {0.0.1},
|
150 |
+
}
|
151 |
+
```
|
152 |
+
|
153 |
+
## How to cite
|
154 |
+
```
|
155 |
+
@misc{JapaneseStableLMBaseAlpha7B,
|
156 |
+
url={[https://huggingface.co/stabilityai/japanese-stablelm-base-alpha-7b](https://huggingface.co/stabilityai/japanese-stablelm-base-alpha-7b)},
|
157 |
+
title={Japanese StableLM Base Alpha 7B},
|
158 |
+
author={Lee, Meng and Nakamura, Fujiki and Shing, Makoto and McCann, Paul and Akiba, Takuya and Orii, Naoki}
|
159 |
+
}
|
160 |
+
```
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "stabilityai/japanese-stablelm-base-alpha-7b",
|
3 |
+
"architectures": [
|
4 |
+
"JapaneseStableLMAlphaForCausalLM"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "stabilityai/japanese-stablelm-base-alpha-7b--configuration_japanese_stablelm_alpha.JapaneseStableLMAlphaConfig",
|
8 |
+
"AutoModelForCausalLM": "stabilityai/japanese-stablelm-base-alpha-7b--modeling_japanese_stablelm_alpha.JapaneseStableLMAlphaForCausalLM"
|
9 |
+
},
|
10 |
+
"bos_token_id": 3,
|
11 |
+
"classifier_dropout": 0.1,
|
12 |
+
"eos_token_id": 3,
|
13 |
+
"hidden_act": "silu",
|
14 |
+
"hidden_size": 4096,
|
15 |
+
"initializer_range": 0.02,
|
16 |
+
"layer_norm_eps": 1e-05,
|
17 |
+
"max_position_embeddings": 2048,
|
18 |
+
"num_attention_heads": 32,
|
19 |
+
"num_hidden_layers": 32,
|
20 |
+
"rotary_emb_base": 10000,
|
21 |
+
"rotary_pct": 0.25,
|
22 |
+
"rotary_scale_base": 512,
|
23 |
+
"tie_word_embeddings": false,
|
24 |
+
"torch_dtype": "float32",
|
25 |
+
"transformers_version": "4.30.2",
|
26 |
+
"use_bias_in_mlp": false,
|
27 |
+
"use_cache": true,
|
28 |
+
"use_parallel_residual": true,
|
29 |
+
"vocab_size": 65536
|
30 |
+
}
|
configuration_japanese_stablelm_alpha.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" JapaneseStableLMAlpha model configuration"""
|
16 |
+
|
17 |
+
from transformers import PretrainedConfig
|
18 |
+
from transformers.utils import logging
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.get_logger(__name__)
|
22 |
+
|
23 |
+
STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
24 |
+
|
25 |
+
|
26 |
+
class JapaneseStableLMAlphaConfig(PretrainedConfig):
|
27 |
+
r"""
|
28 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
29 |
+
documentation from [`PretrainedConfig`] for more information.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
vocab_size (`int`, *optional*, defaults to 65536):
|
33 |
+
Vocabulary size of the JapaneseStableLMAlphaModel. Defines the number of different tokens that
|
34 |
+
can be represented by the `inputs_ids` passed when calling [`JapaneseStableLMAlphaModel`].
|
35 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
36 |
+
Dimension of the decoder layers and the pooler layer.
|
37 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
38 |
+
Number of hidden layers in the Transformer decoder.
|
39 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
40 |
+
Number of attention heads for each attention layer in the Transformer decoder.
|
41 |
+
intermediate_size (`int`, *optional*, defaults to 16384):
|
42 |
+
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer decoder.
|
43 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
44 |
+
The non-linear activation function (function or string).
|
45 |
+
rotary_pct (`float`, *optional*, defaults to 0.25):
|
46 |
+
Percentage of hidden dimensions to allocate to rotary embeddings.
|
47 |
+
rotary_emb_base (`int`, *optional*, defaults to 10000)
|
48 |
+
Base for computing rotary embeddings frequency.
|
49 |
+
rotary_scale_base (`int`, *optional*, defaults to 512)
|
50 |
+
Base `scale` for computing XPos rotary embeddings scale.
|
51 |
+
classifier_dropout (`float`, *optional*, defaults to 0.1):
|
52 |
+
Argument used when doing token classification, used in the model
|
53 |
+
[`StableLMForTokenClassification`]. The dropout ratio for the hidden layer.
|
54 |
+
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
55 |
+
The maximum sequence length that this model might ever be used with.
|
56 |
+
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
|
57 |
+
initializer_range (`float`, *optional*, defaults to 1e-5):
|
58 |
+
The standard deviation of the truncated_normal_initializer for initializing
|
59 |
+
all weight matrices.
|
60 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
61 |
+
The epsilon used by the layer normalization layers.
|
62 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
63 |
+
Whether or not the model should return the last key/values attentions
|
64 |
+
(not used by all models). Only relevant if `config.is_decoder=True`.
|
65 |
+
use_parallel_residual (`bool`, *optional*, defaults to `True`):
|
66 |
+
Whether to use a "parallel" formulation in each Transformer layer,
|
67 |
+
which can provide a slight training speedup at large scales.
|
68 |
+
Example:
|
69 |
+
|
70 |
+
```python
|
71 |
+
>>> from transformers import JapaneseStableLMAlphaConfig, JapaneseStableLMAlphaModel
|
72 |
+
|
73 |
+
>>> # Initializing a JapaneseStableLMAlpha style configuration
|
74 |
+
>>> configuration = JapaneseStableLMAlphaConfig()
|
75 |
+
|
76 |
+
>>> # Initializing a model (with random weights) from the style configuration
|
77 |
+
>>> model = JapaneseStableLMAlphaModel(configuration) # doctest: +SKIP
|
78 |
+
|
79 |
+
>>> # Accessing the model configuration
|
80 |
+
>>> configuration = model.config # doctest: +SKIP
|
81 |
+
```"""
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
vocab_size=65536,
|
85 |
+
hidden_size=4096,
|
86 |
+
num_hidden_layers=32,
|
87 |
+
num_attention_heads=32,
|
88 |
+
hidden_act="silu",
|
89 |
+
rotary_pct=0.25,
|
90 |
+
rotary_emb_base=10000,
|
91 |
+
rotary_scale_base=512,
|
92 |
+
classifier_dropout=0.1,
|
93 |
+
max_position_embeddings=2048,
|
94 |
+
initializer_range=0.02,
|
95 |
+
layer_norm_eps=1e-5,
|
96 |
+
use_cache=True,
|
97 |
+
bos_token_id=3,
|
98 |
+
eos_token_id=3,
|
99 |
+
tie_word_embeddings=False,
|
100 |
+
use_parallel_residual=True,
|
101 |
+
use_bias_in_mlp=True,
|
102 |
+
**kwargs,
|
103 |
+
):
|
104 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
105 |
+
self.vocab_size = vocab_size
|
106 |
+
self.max_position_embeddings = max_position_embeddings
|
107 |
+
self.hidden_size = hidden_size
|
108 |
+
self.num_hidden_layers = num_hidden_layers
|
109 |
+
self.num_attention_heads = num_attention_heads
|
110 |
+
self.hidden_act = hidden_act
|
111 |
+
self.rotary_pct = rotary_pct
|
112 |
+
self.rotary_emb_base = rotary_emb_base
|
113 |
+
self.rotary_scale_base = rotary_scale_base
|
114 |
+
self.classifier_dropout = classifier_dropout
|
115 |
+
self.initializer_range = initializer_range
|
116 |
+
self.layer_norm_eps = layer_norm_eps
|
117 |
+
self.use_cache = use_cache
|
118 |
+
self.tie_word_embeddings = tie_word_embeddings
|
119 |
+
self.use_parallel_residual = use_parallel_residual
|
120 |
+
self.use_bias_in_mlp = use_bias_in_mlp
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 3,
|
4 |
+
"eos_token_id": 3,
|
5 |
+
"transformers_version": "4.30.2"
|
6 |
+
}
|
japanese-stablelm-parrot.jpg
ADDED
Git LFS Details
|
modeling_japanese_stablelm_alpha.py
ADDED
@@ -0,0 +1,682 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch JapaneseStableLMAlpha model. """
|
16 |
+
from typing import Optional, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.utils.checkpoint
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import CrossEntropyLoss
|
22 |
+
from transformers.modeling_outputs import (
|
23 |
+
BaseModelOutputWithPast,
|
24 |
+
CausalLMOutputWithPast,
|
25 |
+
)
|
26 |
+
from transformers.modeling_utils import PreTrainedModel
|
27 |
+
from transformers.utils import logging
|
28 |
+
from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
|
35 |
+
"""
|
36 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
37 |
+
models.
|
38 |
+
"""
|
39 |
+
|
40 |
+
config_class = JapaneseStableLMAlphaConfig
|
41 |
+
base_model_prefix = "transformer"
|
42 |
+
supports_gradient_checkpointing = True
|
43 |
+
_no_split_modules = ["DecoderLayer"]
|
44 |
+
_skip_keys_device_placement = "past_key_values"
|
45 |
+
|
46 |
+
def _init_weights(self, module):
|
47 |
+
"""Initialize the weights"""
|
48 |
+
if isinstance(module, nn.Linear):
|
49 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
50 |
+
if module.bias is not None:
|
51 |
+
module.bias.data.zero_()
|
52 |
+
elif isinstance(module, nn.Embedding):
|
53 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
54 |
+
if module.padding_idx is not None:
|
55 |
+
module.weight.data[module.padding_idx].zero_()
|
56 |
+
elif isinstance(module, nn.LayerNorm):
|
57 |
+
if module.bias is not None:
|
58 |
+
module.bias.data.zero_()
|
59 |
+
if module.weight is not None:
|
60 |
+
module.weight.data.fill_(1.0)
|
61 |
+
|
62 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
63 |
+
if isinstance(module, JapaneseStableLMAlphaModel):
|
64 |
+
module.gradient_checkpointing = value
|
65 |
+
|
66 |
+
|
67 |
+
class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
|
68 |
+
def __init__(self, config):
|
69 |
+
super().__init__(config)
|
70 |
+
self.config = config
|
71 |
+
|
72 |
+
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
|
73 |
+
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
74 |
+
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
75 |
+
|
76 |
+
self.gradient_checkpointing = False
|
77 |
+
|
78 |
+
# Initialize weights and apply final processing
|
79 |
+
self.post_init()
|
80 |
+
|
81 |
+
def get_input_embeddings(self):
|
82 |
+
return self.embed_in
|
83 |
+
|
84 |
+
def set_input_embeddings(self, value):
|
85 |
+
self.embed_in = value
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
input_ids: Optional[torch.LongTensor] = None,
|
90 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
91 |
+
position_ids: Optional[torch.LongTensor] = None,
|
92 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
93 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
94 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
95 |
+
use_cache: Optional[bool] = None,
|
96 |
+
output_attentions: Optional[bool] = None,
|
97 |
+
output_hidden_states: Optional[bool] = None,
|
98 |
+
return_dict: Optional[bool] = None,
|
99 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
100 |
+
r"""
|
101 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
102 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
103 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
104 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
105 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
106 |
+
use_cache (`bool`, *optional*):
|
107 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
108 |
+
`past_key_values`).
|
109 |
+
"""
|
110 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
111 |
+
output_hidden_states = (
|
112 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
113 |
+
)
|
114 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
115 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
116 |
+
|
117 |
+
if input_ids is not None and inputs_embeds is not None:
|
118 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
119 |
+
elif input_ids is not None:
|
120 |
+
input_shape = input_ids.size()
|
121 |
+
elif inputs_embeds is not None:
|
122 |
+
input_shape = inputs_embeds.size()[:-1]
|
123 |
+
else:
|
124 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
125 |
+
|
126 |
+
batch_size, seq_length = input_shape
|
127 |
+
|
128 |
+
if past_key_values is None:
|
129 |
+
past_length = 0
|
130 |
+
past_key_values = tuple([None] * self.config.num_hidden_layers)
|
131 |
+
else:
|
132 |
+
past_length = past_key_values[0][0].size(-2)
|
133 |
+
|
134 |
+
if position_ids is None:
|
135 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
136 |
+
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
|
137 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
138 |
+
else:
|
139 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
140 |
+
|
141 |
+
# Attention mask.
|
142 |
+
if attention_mask is not None:
|
143 |
+
assert batch_size > 0, "batch_size has to be defined and > 0"
|
144 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
145 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
146 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
147 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
148 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
149 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
150 |
+
attention_mask = attention_mask[:, None, None, :]
|
151 |
+
|
152 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
153 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
154 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
155 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
156 |
+
# effectively the same as removing these entirely.
|
157 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
158 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
159 |
+
|
160 |
+
# Prepare head mask if needed
|
161 |
+
# 1.0 in head_mask indicate we keep the head
|
162 |
+
# attention_probs has shape bsz x n_heads x N x N
|
163 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
164 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
165 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
166 |
+
|
167 |
+
if inputs_embeds is None:
|
168 |
+
inputs_embeds = self.embed_in(input_ids)
|
169 |
+
|
170 |
+
hidden_states = inputs_embeds
|
171 |
+
|
172 |
+
if self.gradient_checkpointing and self.training:
|
173 |
+
if use_cache:
|
174 |
+
logger.warning(
|
175 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
176 |
+
)
|
177 |
+
use_cache = False
|
178 |
+
|
179 |
+
presents = () if use_cache else None
|
180 |
+
all_attentions = () if output_attentions else None
|
181 |
+
all_hidden_states = () if output_hidden_states else None
|
182 |
+
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
|
183 |
+
if output_hidden_states:
|
184 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
185 |
+
|
186 |
+
if self.gradient_checkpointing and self.training:
|
187 |
+
|
188 |
+
def create_custom_forward(module):
|
189 |
+
def custom_forward(*inputs):
|
190 |
+
# None for layer_past
|
191 |
+
return module(*inputs, use_cache, None, output_attentions)
|
192 |
+
|
193 |
+
return custom_forward
|
194 |
+
|
195 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
196 |
+
create_custom_forward(layer),
|
197 |
+
hidden_states,
|
198 |
+
attention_mask,
|
199 |
+
position_ids,
|
200 |
+
head_mask[i],
|
201 |
+
)
|
202 |
+
else:
|
203 |
+
outputs = layer(
|
204 |
+
hidden_states,
|
205 |
+
attention_mask=attention_mask,
|
206 |
+
position_ids=position_ids,
|
207 |
+
head_mask=head_mask[i],
|
208 |
+
layer_past=layer_past,
|
209 |
+
use_cache=use_cache,
|
210 |
+
output_attentions=output_attentions,
|
211 |
+
)
|
212 |
+
hidden_states = outputs[0]
|
213 |
+
if use_cache is True:
|
214 |
+
presents = presents + (outputs[1],)
|
215 |
+
if output_attentions:
|
216 |
+
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
|
217 |
+
|
218 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
219 |
+
# Add last hidden state
|
220 |
+
if output_hidden_states:
|
221 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
222 |
+
|
223 |
+
if not return_dict:
|
224 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
225 |
+
|
226 |
+
return BaseModelOutputWithPast(
|
227 |
+
last_hidden_state=hidden_states,
|
228 |
+
past_key_values=presents,
|
229 |
+
hidden_states=all_hidden_states,
|
230 |
+
attentions=all_attentions,
|
231 |
+
)
|
232 |
+
|
233 |
+
|
234 |
+
class DecoderLayer(nn.Module):
|
235 |
+
def __init__(self, config):
|
236 |
+
super().__init__()
|
237 |
+
self.use_parallel_residual = config.use_parallel_residual
|
238 |
+
self.input_layernorm = nn.LayerNorm(
|
239 |
+
config.hidden_size,
|
240 |
+
eps=config.layer_norm_eps,
|
241 |
+
elementwise_affine=False,
|
242 |
+
)
|
243 |
+
self.post_attention_layernorm = nn.LayerNorm(
|
244 |
+
config.hidden_size,
|
245 |
+
eps=config.layer_norm_eps
|
246 |
+
)
|
247 |
+
self.attention = Attention(config)
|
248 |
+
self.mlp = MLP(config)
|
249 |
+
|
250 |
+
def forward(
|
251 |
+
self,
|
252 |
+
hidden_states: Optional[torch.FloatTensor],
|
253 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
254 |
+
position_ids: Optional[torch.LongTensor] = None,
|
255 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
256 |
+
use_cache: Optional[bool] = False,
|
257 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
258 |
+
output_attentions: Optional[bool] = False,
|
259 |
+
):
|
260 |
+
attention_layer_outputs = self.attention(
|
261 |
+
self.input_layernorm(hidden_states),
|
262 |
+
attention_mask=attention_mask,
|
263 |
+
position_ids=position_ids,
|
264 |
+
layer_past=layer_past,
|
265 |
+
head_mask=head_mask,
|
266 |
+
use_cache=use_cache,
|
267 |
+
output_attentions=output_attentions,
|
268 |
+
)
|
269 |
+
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
|
270 |
+
outputs = attention_layer_outputs[1:]
|
271 |
+
|
272 |
+
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
|
273 |
+
hidden_states = hidden_states + mlp_output + attn_output
|
274 |
+
|
275 |
+
if use_cache:
|
276 |
+
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
|
277 |
+
else:
|
278 |
+
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
|
279 |
+
|
280 |
+
return outputs
|
281 |
+
|
282 |
+
|
283 |
+
class MLP(nn.Module):
|
284 |
+
def __init__(self, config: JapaneseStableLMAlphaConfig):
|
285 |
+
super().__init__()
|
286 |
+
hidden_size = config.hidden_size
|
287 |
+
multiple_of = 256
|
288 |
+
ff_dim = int(8 * hidden_size / 3)
|
289 |
+
intermediate_size = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
|
290 |
+
|
291 |
+
self.packed_input_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
|
292 |
+
self.out_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
293 |
+
self.act = nn.SiLU()
|
294 |
+
|
295 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
296 |
+
ff, ff_gate = self.packed_input_proj(x).chunk(2, dim=-1)
|
297 |
+
return self.out_proj(ff * self.act(ff_gate))
|
298 |
+
|
299 |
+
|
300 |
+
class RotaryEmbedding(torch.nn.Module):
|
301 |
+
"""Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py"""
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
dim: int,
|
305 |
+
max_position_embeddings: int,
|
306 |
+
base: int = 10_000,
|
307 |
+
scale_base: int = 512,
|
308 |
+
device: str = None
|
309 |
+
):
|
310 |
+
super().__init__()
|
311 |
+
self.dim = dim
|
312 |
+
self.seq_len_cached = max_position_embeddings
|
313 |
+
|
314 |
+
# Set up `inv_freq` term
|
315 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
|
316 |
+
self.register_buffer("inv_freq", inv_freq)
|
317 |
+
|
318 |
+
# Set up `scale` term
|
319 |
+
self.scale_base = scale_base
|
320 |
+
scale = (
|
321 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
|
322 |
+
if scale_base is not None else None
|
323 |
+
)
|
324 |
+
self.register_buffer("scale", scale)
|
325 |
+
|
326 |
+
# Seet up `cos..` and `sin...` cache terms
|
327 |
+
t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32)
|
328 |
+
freqs = torch.outer(t, self.inv_freq)
|
329 |
+
# freqs = torch.cat((freqs, freqs), dim=-1)
|
330 |
+
seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
|
331 |
+
power = (seq_range - self.seq_len_cached // 2) / self.scale_base
|
332 |
+
scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
|
333 |
+
# scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
334 |
+
self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
|
335 |
+
self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
|
336 |
+
self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
|
337 |
+
self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
|
338 |
+
|
339 |
+
def forward(self, x, seq_len=None):
|
340 |
+
if seq_len > self.seq_len_cached:
|
341 |
+
self.seq_len_cached = seq_len
|
342 |
+
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
|
343 |
+
freqs = torch.outer(t, self.inv_freq)
|
344 |
+
freqs = torch.cat((freqs, freqs), dim=-1)
|
345 |
+
seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device)
|
346 |
+
power = (seq_range - self.seq_len_cached // 2) / self.scale_base
|
347 |
+
scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1)
|
348 |
+
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
349 |
+
self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
|
350 |
+
self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
|
351 |
+
self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
|
352 |
+
self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
|
353 |
+
return (
|
354 |
+
self.cos_cached[:seq_len, ...],
|
355 |
+
self.sin_cached[:seq_len, ...],
|
356 |
+
self.cos_k_cached[:seq_len, ...],
|
357 |
+
self.sin_k_cached[:seq_len, ...],
|
358 |
+
)
|
359 |
+
|
360 |
+
|
361 |
+
def rotate_half(x):
|
362 |
+
x1, x2 = x.chunk(2, dim=-1)
|
363 |
+
return torch.cat((-x2, x1), dim=-1)
|
364 |
+
|
365 |
+
|
366 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
|
367 |
+
"""
|
368 |
+
q, k: [bs, num_heads, seq_len, rot_dim]
|
369 |
+
cos, sin: [seq_len, rot_dim / 2]
|
370 |
+
position_ids: [bs, seq_len]
|
371 |
+
"""
|
372 |
+
# print(f"q: {q.shape}, k: {k.shape}, cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids.shape}")
|
373 |
+
import einops
|
374 |
+
cos = einops.repeat(cos, 's r -> s (2 r)')
|
375 |
+
sin = einops.repeat(sin, 's r -> s (2 r)')
|
376 |
+
cos_k = einops.repeat(cos_k, 's r -> s (2 r)')
|
377 |
+
sin_k = einops.repeat(sin_k, 's r -> s (2 r)')
|
378 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
|
379 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
|
380 |
+
cos_k = cos_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
|
381 |
+
sin_k = sin_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim]
|
382 |
+
|
383 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
384 |
+
k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
|
385 |
+
return q_embed, k_embed
|
386 |
+
|
387 |
+
|
388 |
+
class Attention(nn.Module):
|
389 |
+
def __init__(self, config):
|
390 |
+
super().__init__()
|
391 |
+
self.num_attention_heads = config.num_attention_heads
|
392 |
+
self.hidden_size = config.hidden_size
|
393 |
+
if self.hidden_size % self.num_attention_heads != 0:
|
394 |
+
raise ValueError(
|
395 |
+
"The hidden size is not divisble by the number of attention heads! Make sure to update them"
|
396 |
+
)
|
397 |
+
self.head_size = self.hidden_size // self.num_attention_heads
|
398 |
+
|
399 |
+
max_positions = config.max_position_embeddings
|
400 |
+
self.register_buffer(
|
401 |
+
"bias",
|
402 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
403 |
+
1, 1, max_positions, max_positions
|
404 |
+
),
|
405 |
+
persistent=False,
|
406 |
+
)
|
407 |
+
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
|
408 |
+
|
409 |
+
self.rotary_ndims = int(self.head_size * config.rotary_pct)
|
410 |
+
self.rotary_emb = RotaryEmbedding(
|
411 |
+
self.rotary_ndims,
|
412 |
+
max_position_embeddings=config.max_position_embeddings,
|
413 |
+
base=config.rotary_emb_base,
|
414 |
+
scale_base=config.rotary_scale_base,
|
415 |
+
)
|
416 |
+
|
417 |
+
self.register_buffer(
|
418 |
+
"norm_factor",
|
419 |
+
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
|
420 |
+
persistent=False,
|
421 |
+
)
|
422 |
+
|
423 |
+
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
|
424 |
+
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
|
425 |
+
|
426 |
+
def forward(
|
427 |
+
self,
|
428 |
+
hidden_states: torch.FloatTensor,
|
429 |
+
attention_mask: torch.FloatTensor,
|
430 |
+
position_ids: torch.LongTensor,
|
431 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
432 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
433 |
+
use_cache: Optional[bool] = False,
|
434 |
+
output_attentions: Optional[bool] = False,
|
435 |
+
):
|
436 |
+
has_layer_past = layer_past is not None
|
437 |
+
|
438 |
+
# Compute QKV
|
439 |
+
# Attention heads [batch, seq_len, hidden_size]
|
440 |
+
# --> [batch, seq_len, (np * 3 * head_size)]
|
441 |
+
qkv = self.query_key_value(hidden_states)
|
442 |
+
|
443 |
+
# [batch, seq_len, (num_heads * 3 * head_size)]
|
444 |
+
# --> [batch, seq_len, num_heads, 3 * head_size]
|
445 |
+
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
|
446 |
+
qkv = qkv.view(*new_qkv_shape)
|
447 |
+
|
448 |
+
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
|
449 |
+
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
|
450 |
+
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
|
451 |
+
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
|
452 |
+
|
453 |
+
# Compute rotary embeddings on rotary_ndims
|
454 |
+
query_rot = query[..., : self.rotary_ndims]
|
455 |
+
query_pass = query[..., self.rotary_ndims :]
|
456 |
+
key_rot = key[..., : self.rotary_ndims]
|
457 |
+
key_pass = key[..., self.rotary_ndims :]
|
458 |
+
|
459 |
+
# Compute token offset for rotary embeddings (when decoding)
|
460 |
+
kv_seq_len = key.shape[-2]
|
461 |
+
if has_layer_past:
|
462 |
+
kv_seq_len += layer_past[0].shape[-2]
|
463 |
+
|
464 |
+
# Add rotary embeddings to query and key
|
465 |
+
# TODO: Check if using xpos
|
466 |
+
cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
|
467 |
+
query, key = apply_rotary_pos_emb(
|
468 |
+
query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k)
|
469 |
+
|
470 |
+
query = torch.cat((query, query_pass), dim=-1)
|
471 |
+
key = torch.cat((key, key_pass), dim=-1)
|
472 |
+
|
473 |
+
# Cache QKV values
|
474 |
+
if has_layer_past:
|
475 |
+
past_key = layer_past[0]
|
476 |
+
past_value = layer_past[1]
|
477 |
+
key = torch.cat((past_key, key), dim=-2)
|
478 |
+
value = torch.cat((past_value, value), dim=-2)
|
479 |
+
present = (key, value) if use_cache else None
|
480 |
+
|
481 |
+
# Compute attention
|
482 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
483 |
+
|
484 |
+
# Merge attn_head_size dim and num_attn_heads dim into hidden dim
|
485 |
+
# [bs, seq_len, num_attention_heads, attn_head_size]
|
486 |
+
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
|
487 |
+
attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size)
|
488 |
+
|
489 |
+
attn_output = self.dense(attn_output)
|
490 |
+
|
491 |
+
outputs = (attn_output, present)
|
492 |
+
if output_attentions:
|
493 |
+
outputs += (attn_weights,)
|
494 |
+
|
495 |
+
return outputs
|
496 |
+
|
497 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
498 |
+
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
|
499 |
+
# compute causal mask from causal mask buffer
|
500 |
+
|
501 |
+
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
|
502 |
+
key_length = key.size(-2)
|
503 |
+
|
504 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
505 |
+
|
506 |
+
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
|
507 |
+
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
|
508 |
+
attn_scores = torch.zeros(
|
509 |
+
batch_size * num_attention_heads,
|
510 |
+
query_length,
|
511 |
+
key_length,
|
512 |
+
dtype=query.dtype,
|
513 |
+
device=key.device,
|
514 |
+
)
|
515 |
+
attn_scores = torch.baddbmm(
|
516 |
+
attn_scores,
|
517 |
+
query,
|
518 |
+
key.transpose(1, 2),
|
519 |
+
beta=1.0,
|
520 |
+
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
|
521 |
+
)
|
522 |
+
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
|
523 |
+
|
524 |
+
mask_value = torch.finfo(attn_scores.dtype).min
|
525 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
526 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
527 |
+
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
|
528 |
+
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
|
529 |
+
|
530 |
+
if attention_mask is not None:
|
531 |
+
# Apply the attention mask
|
532 |
+
attn_scores = attn_scores + attention_mask
|
533 |
+
|
534 |
+
# NOTE: Upcast to float32
|
535 |
+
attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value)
|
536 |
+
|
537 |
+
# Mask heads if we want to
|
538 |
+
if head_mask is not None:
|
539 |
+
attn_weights = attn_weights * head_mask
|
540 |
+
|
541 |
+
attn_output = torch.matmul(attn_weights, value)
|
542 |
+
return attn_output, attn_weights
|
543 |
+
|
544 |
+
|
545 |
+
def attention_mask_func(attention_scores, ltor_mask):
|
546 |
+
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
|
547 |
+
return attention_scores
|
548 |
+
|
549 |
+
|
550 |
+
class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
|
551 |
+
_tied_weights_keys = ["embed_out.weight"]
|
552 |
+
|
553 |
+
def __init__(self, config):
|
554 |
+
super().__init__(config)
|
555 |
+
|
556 |
+
self.transformer = JapaneseStableLMAlphaModel(config)
|
557 |
+
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
558 |
+
|
559 |
+
# Initialize weights and apply final processing
|
560 |
+
self.post_init()
|
561 |
+
|
562 |
+
def get_output_embeddings(self):
|
563 |
+
return self.embed_out
|
564 |
+
|
565 |
+
def set_output_embeddings(self, new_embeddings):
|
566 |
+
self.embed_out = new_embeddings
|
567 |
+
|
568 |
+
def forward(
|
569 |
+
self,
|
570 |
+
input_ids: Optional[torch.LongTensor] = None,
|
571 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
572 |
+
position_ids: Optional[torch.LongTensor] = None,
|
573 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
574 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
575 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
576 |
+
labels: Optional[torch.LongTensor] = None,
|
577 |
+
use_cache: Optional[bool] = None,
|
578 |
+
output_attentions: Optional[bool] = None,
|
579 |
+
output_hidden_states: Optional[bool] = None,
|
580 |
+
return_dict: Optional[bool] = None,
|
581 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
582 |
+
r"""
|
583 |
+
Example:
|
584 |
+
|
585 |
+
```python
|
586 |
+
>>> import torch
|
587 |
+
>>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig
|
588 |
+
|
589 |
+
>>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
|
590 |
+
>>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
|
591 |
+
>>> config.is_decoder = True
|
592 |
+
>>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)
|
593 |
+
|
594 |
+
>>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
|
595 |
+
>>> outputs = model(**inputs)
|
596 |
+
|
597 |
+
>>> prediction_logits = outputs.logits
|
598 |
+
```"""
|
599 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
600 |
+
|
601 |
+
outputs = self.transformer(
|
602 |
+
input_ids,
|
603 |
+
attention_mask=attention_mask,
|
604 |
+
position_ids=position_ids,
|
605 |
+
head_mask=head_mask,
|
606 |
+
inputs_embeds=inputs_embeds,
|
607 |
+
past_key_values=past_key_values,
|
608 |
+
use_cache=use_cache,
|
609 |
+
output_attentions=output_attentions,
|
610 |
+
output_hidden_states=output_hidden_states,
|
611 |
+
return_dict=return_dict,
|
612 |
+
)
|
613 |
+
|
614 |
+
hidden_states = outputs[0]
|
615 |
+
lm_logits = self.embed_out(hidden_states)
|
616 |
+
|
617 |
+
lm_loss = None
|
618 |
+
if labels is not None:
|
619 |
+
# move labels to correct device to enable model parallelism
|
620 |
+
labels = labels.to(lm_logits.device)
|
621 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
622 |
+
shift_logits = lm_logits[:, :-1, :].contiguous()
|
623 |
+
labels = labels[:, 1:].contiguous()
|
624 |
+
loss_fct = CrossEntropyLoss()
|
625 |
+
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
|
626 |
+
|
627 |
+
if not return_dict:
|
628 |
+
output = (lm_logits,) + outputs[1:]
|
629 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
630 |
+
|
631 |
+
return CausalLMOutputWithPast(
|
632 |
+
loss=lm_loss,
|
633 |
+
logits=lm_logits,
|
634 |
+
past_key_values=outputs.past_key_values,
|
635 |
+
hidden_states=outputs.hidden_states,
|
636 |
+
attentions=outputs.attentions,
|
637 |
+
)
|
638 |
+
|
639 |
+
def prepare_inputs_for_generation(
|
640 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
641 |
+
):
|
642 |
+
input_shape = input_ids.shape
|
643 |
+
|
644 |
+
# cut decoder_input_ids if past is used
|
645 |
+
if past_key_values and past_key_values[0] is not None:
|
646 |
+
input_ids = input_ids[:, -1:]
|
647 |
+
|
648 |
+
position_ids = kwargs.get("position_ids", None)
|
649 |
+
if attention_mask is not None and position_ids is None:
|
650 |
+
# create position_ids on the fly for batch generation
|
651 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
652 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
653 |
+
if past_key_values:
|
654 |
+
position_ids = position_ids[:, -1].unsqueeze(-1)
|
655 |
+
|
656 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
657 |
+
if attention_mask is None:
|
658 |
+
attention_mask = input_ids.new_ones(input_shape)
|
659 |
+
|
660 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
661 |
+
if inputs_embeds is not None and past_key_values is None:
|
662 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
663 |
+
else:
|
664 |
+
model_inputs = {"input_ids": input_ids}
|
665 |
+
|
666 |
+
model_inputs.update(
|
667 |
+
{
|
668 |
+
"attention_mask": attention_mask,
|
669 |
+
"past_key_values": past_key_values,
|
670 |
+
"position_ids": position_ids,
|
671 |
+
}
|
672 |
+
)
|
673 |
+
|
674 |
+
return model_inputs
|
675 |
+
|
676 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
677 |
+
reordered_past = ()
|
678 |
+
for layer_past in past_key_values:
|
679 |
+
reordered_past += (
|
680 |
+
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
|
681 |
+
)
|
682 |
+
return reordered_past
|
pytorch_model-00001-of-00003.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:299e36821f331d32271e0784107e004682e5acba667ba93276c54e23567922a0
|
3 |
+
size 9978676569
|
pytorch_model-00002-of-00003.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:335ad506d6255a4bb7629e231196b40581d870fcaf3ac7e9c7c39d54cd160770
|
3 |
+
size 9982872727
|
pytorch_model-00003-of-00003.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d831fc7e8001819e69823c4590128a4059ba3c2c7df79fbeee11a91f606b149
|
3 |
+
size 8091132329
|
pytorch_model.bin.index.json
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 28052590592
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"embed_out.weight": "pytorch_model-00003-of-00003.bin",
|
7 |
+
"transformer.embed_in.weight": "pytorch_model-00001-of-00003.bin",
|
8 |
+
"transformer.final_layer_norm.bias": "pytorch_model-00003-of-00003.bin",
|
9 |
+
"transformer.final_layer_norm.weight": "pytorch_model-00003-of-00003.bin",
|
10 |
+
"transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
11 |
+
"transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
12 |
+
"transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
13 |
+
"transformer.layers.0.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
14 |
+
"transformer.layers.0.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
15 |
+
"transformer.layers.0.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
16 |
+
"transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
17 |
+
"transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
18 |
+
"transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
19 |
+
"transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
20 |
+
"transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
21 |
+
"transformer.layers.1.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
22 |
+
"transformer.layers.1.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
23 |
+
"transformer.layers.1.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
24 |
+
"transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
25 |
+
"transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
26 |
+
"transformer.layers.10.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
27 |
+
"transformer.layers.10.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
28 |
+
"transformer.layers.10.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
29 |
+
"transformer.layers.10.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
30 |
+
"transformer.layers.10.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
31 |
+
"transformer.layers.10.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
32 |
+
"transformer.layers.10.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
33 |
+
"transformer.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
34 |
+
"transformer.layers.11.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
35 |
+
"transformer.layers.11.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
36 |
+
"transformer.layers.11.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
37 |
+
"transformer.layers.11.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
38 |
+
"transformer.layers.11.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
39 |
+
"transformer.layers.11.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
40 |
+
"transformer.layers.11.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
41 |
+
"transformer.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
42 |
+
"transformer.layers.12.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
43 |
+
"transformer.layers.12.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
44 |
+
"transformer.layers.12.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
45 |
+
"transformer.layers.12.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
46 |
+
"transformer.layers.12.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
47 |
+
"transformer.layers.12.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
48 |
+
"transformer.layers.12.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
49 |
+
"transformer.layers.12.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
50 |
+
"transformer.layers.13.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
51 |
+
"transformer.layers.13.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
52 |
+
"transformer.layers.13.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
53 |
+
"transformer.layers.13.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
54 |
+
"transformer.layers.13.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
55 |
+
"transformer.layers.13.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
56 |
+
"transformer.layers.13.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
57 |
+
"transformer.layers.13.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
58 |
+
"transformer.layers.14.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
59 |
+
"transformer.layers.14.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
60 |
+
"transformer.layers.14.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
61 |
+
"transformer.layers.14.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
62 |
+
"transformer.layers.14.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
63 |
+
"transformer.layers.14.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
64 |
+
"transformer.layers.14.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
65 |
+
"transformer.layers.14.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
66 |
+
"transformer.layers.15.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
67 |
+
"transformer.layers.15.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
68 |
+
"transformer.layers.15.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
69 |
+
"transformer.layers.15.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
70 |
+
"transformer.layers.15.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
71 |
+
"transformer.layers.15.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
72 |
+
"transformer.layers.15.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
73 |
+
"transformer.layers.15.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
74 |
+
"transformer.layers.16.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
75 |
+
"transformer.layers.16.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
76 |
+
"transformer.layers.16.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
77 |
+
"transformer.layers.16.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
78 |
+
"transformer.layers.16.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
79 |
+
"transformer.layers.16.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
80 |
+
"transformer.layers.16.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
81 |
+
"transformer.layers.16.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
82 |
+
"transformer.layers.17.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
83 |
+
"transformer.layers.17.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
84 |
+
"transformer.layers.17.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
85 |
+
"transformer.layers.17.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
86 |
+
"transformer.layers.17.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
87 |
+
"transformer.layers.17.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
88 |
+
"transformer.layers.17.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
89 |
+
"transformer.layers.17.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
90 |
+
"transformer.layers.18.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
91 |
+
"transformer.layers.18.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
92 |
+
"transformer.layers.18.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
93 |
+
"transformer.layers.18.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
94 |
+
"transformer.layers.18.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
95 |
+
"transformer.layers.18.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
96 |
+
"transformer.layers.18.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
97 |
+
"transformer.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
98 |
+
"transformer.layers.19.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
99 |
+
"transformer.layers.19.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
100 |
+
"transformer.layers.19.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
101 |
+
"transformer.layers.19.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
102 |
+
"transformer.layers.19.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
103 |
+
"transformer.layers.19.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
104 |
+
"transformer.layers.19.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
105 |
+
"transformer.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
106 |
+
"transformer.layers.2.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
107 |
+
"transformer.layers.2.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
108 |
+
"transformer.layers.2.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
109 |
+
"transformer.layers.2.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
110 |
+
"transformer.layers.2.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
111 |
+
"transformer.layers.2.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
112 |
+
"transformer.layers.2.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
113 |
+
"transformer.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
114 |
+
"transformer.layers.20.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
115 |
+
"transformer.layers.20.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
116 |
+
"transformer.layers.20.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
117 |
+
"transformer.layers.20.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
118 |
+
"transformer.layers.20.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
119 |
+
"transformer.layers.20.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
120 |
+
"transformer.layers.20.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
121 |
+
"transformer.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
122 |
+
"transformer.layers.21.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
123 |
+
"transformer.layers.21.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
124 |
+
"transformer.layers.21.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
125 |
+
"transformer.layers.21.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
126 |
+
"transformer.layers.21.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
127 |
+
"transformer.layers.21.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
128 |
+
"transformer.layers.21.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
129 |
+
"transformer.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
130 |
+
"transformer.layers.22.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
131 |
+
"transformer.layers.22.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
132 |
+
"transformer.layers.22.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
133 |
+
"transformer.layers.22.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
134 |
+
"transformer.layers.22.mlp.out_proj.weight": "pytorch_model-00002-of-00003.bin",
|
135 |
+
"transformer.layers.22.mlp.packed_input_proj.weight": "pytorch_model-00002-of-00003.bin",
|
136 |
+
"transformer.layers.22.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
137 |
+
"transformer.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
138 |
+
"transformer.layers.23.attention.dense.weight": "pytorch_model-00002-of-00003.bin",
|
139 |
+
"transformer.layers.23.attention.query_key_value.weight": "pytorch_model-00002-of-00003.bin",
|
140 |
+
"transformer.layers.23.attention.rotary_emb.inv_freq": "pytorch_model-00002-of-00003.bin",
|
141 |
+
"transformer.layers.23.attention.rotary_emb.scale": "pytorch_model-00002-of-00003.bin",
|
142 |
+
"transformer.layers.23.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
143 |
+
"transformer.layers.23.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
144 |
+
"transformer.layers.23.post_attention_layernorm.bias": "pytorch_model-00002-of-00003.bin",
|
145 |
+
"transformer.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin",
|
146 |
+
"transformer.layers.24.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
147 |
+
"transformer.layers.24.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
148 |
+
"transformer.layers.24.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
149 |
+
"transformer.layers.24.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
150 |
+
"transformer.layers.24.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
151 |
+
"transformer.layers.24.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
152 |
+
"transformer.layers.24.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
153 |
+
"transformer.layers.24.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
154 |
+
"transformer.layers.25.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
155 |
+
"transformer.layers.25.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
156 |
+
"transformer.layers.25.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
157 |
+
"transformer.layers.25.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
158 |
+
"transformer.layers.25.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
159 |
+
"transformer.layers.25.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
160 |
+
"transformer.layers.25.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
161 |
+
"transformer.layers.25.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
162 |
+
"transformer.layers.26.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
163 |
+
"transformer.layers.26.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
164 |
+
"transformer.layers.26.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
165 |
+
"transformer.layers.26.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
166 |
+
"transformer.layers.26.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
167 |
+
"transformer.layers.26.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
168 |
+
"transformer.layers.26.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
169 |
+
"transformer.layers.26.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
170 |
+
"transformer.layers.27.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
171 |
+
"transformer.layers.27.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
172 |
+
"transformer.layers.27.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
173 |
+
"transformer.layers.27.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
174 |
+
"transformer.layers.27.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
175 |
+
"transformer.layers.27.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
176 |
+
"transformer.layers.27.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
177 |
+
"transformer.layers.27.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
178 |
+
"transformer.layers.28.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
179 |
+
"transformer.layers.28.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
180 |
+
"transformer.layers.28.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
181 |
+
"transformer.layers.28.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
182 |
+
"transformer.layers.28.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
183 |
+
"transformer.layers.28.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
184 |
+
"transformer.layers.28.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
185 |
+
"transformer.layers.28.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
186 |
+
"transformer.layers.29.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
187 |
+
"transformer.layers.29.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
188 |
+
"transformer.layers.29.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
189 |
+
"transformer.layers.29.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
190 |
+
"transformer.layers.29.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
191 |
+
"transformer.layers.29.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
192 |
+
"transformer.layers.29.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
193 |
+
"transformer.layers.29.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
194 |
+
"transformer.layers.3.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
195 |
+
"transformer.layers.3.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
196 |
+
"transformer.layers.3.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
197 |
+
"transformer.layers.3.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
198 |
+
"transformer.layers.3.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
199 |
+
"transformer.layers.3.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
200 |
+
"transformer.layers.3.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
201 |
+
"transformer.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
202 |
+
"transformer.layers.30.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
203 |
+
"transformer.layers.30.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
204 |
+
"transformer.layers.30.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
205 |
+
"transformer.layers.30.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
206 |
+
"transformer.layers.30.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
207 |
+
"transformer.layers.30.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
208 |
+
"transformer.layers.30.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
209 |
+
"transformer.layers.30.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
210 |
+
"transformer.layers.31.attention.dense.weight": "pytorch_model-00003-of-00003.bin",
|
211 |
+
"transformer.layers.31.attention.query_key_value.weight": "pytorch_model-00003-of-00003.bin",
|
212 |
+
"transformer.layers.31.attention.rotary_emb.inv_freq": "pytorch_model-00003-of-00003.bin",
|
213 |
+
"transformer.layers.31.attention.rotary_emb.scale": "pytorch_model-00003-of-00003.bin",
|
214 |
+
"transformer.layers.31.mlp.out_proj.weight": "pytorch_model-00003-of-00003.bin",
|
215 |
+
"transformer.layers.31.mlp.packed_input_proj.weight": "pytorch_model-00003-of-00003.bin",
|
216 |
+
"transformer.layers.31.post_attention_layernorm.bias": "pytorch_model-00003-of-00003.bin",
|
217 |
+
"transformer.layers.31.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin",
|
218 |
+
"transformer.layers.4.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
219 |
+
"transformer.layers.4.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
220 |
+
"transformer.layers.4.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
221 |
+
"transformer.layers.4.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
222 |
+
"transformer.layers.4.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
223 |
+
"transformer.layers.4.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
224 |
+
"transformer.layers.4.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
225 |
+
"transformer.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
226 |
+
"transformer.layers.5.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
227 |
+
"transformer.layers.5.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
228 |
+
"transformer.layers.5.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
229 |
+
"transformer.layers.5.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
230 |
+
"transformer.layers.5.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
231 |
+
"transformer.layers.5.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
232 |
+
"transformer.layers.5.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
233 |
+
"transformer.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
234 |
+
"transformer.layers.6.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
235 |
+
"transformer.layers.6.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
236 |
+
"transformer.layers.6.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
237 |
+
"transformer.layers.6.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
238 |
+
"transformer.layers.6.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
239 |
+
"transformer.layers.6.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
240 |
+
"transformer.layers.6.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
241 |
+
"transformer.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
242 |
+
"transformer.layers.7.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
243 |
+
"transformer.layers.7.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
244 |
+
"transformer.layers.7.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
245 |
+
"transformer.layers.7.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
246 |
+
"transformer.layers.7.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
247 |
+
"transformer.layers.7.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
248 |
+
"transformer.layers.7.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
249 |
+
"transformer.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
250 |
+
"transformer.layers.8.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
251 |
+
"transformer.layers.8.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
252 |
+
"transformer.layers.8.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
253 |
+
"transformer.layers.8.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
254 |
+
"transformer.layers.8.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
255 |
+
"transformer.layers.8.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
256 |
+
"transformer.layers.8.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
257 |
+
"transformer.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin",
|
258 |
+
"transformer.layers.9.attention.dense.weight": "pytorch_model-00001-of-00003.bin",
|
259 |
+
"transformer.layers.9.attention.query_key_value.weight": "pytorch_model-00001-of-00003.bin",
|
260 |
+
"transformer.layers.9.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00003.bin",
|
261 |
+
"transformer.layers.9.attention.rotary_emb.scale": "pytorch_model-00001-of-00003.bin",
|
262 |
+
"transformer.layers.9.mlp.out_proj.weight": "pytorch_model-00001-of-00003.bin",
|
263 |
+
"transformer.layers.9.mlp.packed_input_proj.weight": "pytorch_model-00001-of-00003.bin",
|
264 |
+
"transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00001-of-00003.bin",
|
265 |
+
"transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin"
|
266 |
+
}
|
267 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
sentencepiece
|
2 |
+
einops
|