Upload 13 files
Browse files- __init__.py +0 -0
- config.json +30 -0
- configuration_rwkv7.py +129 -0
- cuda/state_wkv7_cuda.cu +152 -0
- cuda/state_wkv7_op.cpp +34 -0
- cuda/wkv7_cuda.cu +138 -0
- cuda/wkv7_op.cpp +34 -0
- modeling_blocks_rwkv7.py +0 -0
- modeling_rwkv7.py +460 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- vocab.json +0 -0
__init__.py
ADDED
File without changes
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"RWKV7ForCasualLM",
|
4 |
+
"RWKV7Model",
|
5 |
+
"RWKV7PreTrainedModel"
|
6 |
+
],
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"device": null,
|
9 |
+
"dropout_rate": 0.0,
|
10 |
+
"dtype": null,
|
11 |
+
"eos_token_id": 0,
|
12 |
+
"head_size": 64,
|
13 |
+
"hidden_size": 2048,
|
14 |
+
"hidden_size_att": 2048,
|
15 |
+
"hidden_size_ffn": 8192,
|
16 |
+
"init_state_wkv": false,
|
17 |
+
"layer_id": null,
|
18 |
+
"model_type": "rwkv7",
|
19 |
+
"num_hidden_layers": 24,
|
20 |
+
"tie_word_embeddings": false,
|
21 |
+
"tmix_backend": "auto",
|
22 |
+
"torch_dtype": "bfloat16",
|
23 |
+
"transformers_version": "4.48.0",
|
24 |
+
"vocab_size": 50304,
|
25 |
+
"auto_map": {
|
26 |
+
"AutoConfig": "configuration_rwkv7.RWKV7Config",
|
27 |
+
"AutoModel": "modeling_rwkv7.RWKV7Model",
|
28 |
+
"AutoModelForCausalLM": "modeling_rwkv7.RWKV7ForCasualLM"
|
29 |
+
}
|
30 |
+
}
|
configuration_rwkv7.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" RWKV configuration"""
|
2 |
+
|
3 |
+
from transformers.configuration_utils import PretrainedConfig
|
4 |
+
# from transformers.utils import logging
|
5 |
+
# logger = logging.get_logger(__name__)
|
6 |
+
|
7 |
+
# Import the dependencies
|
8 |
+
from .modeling_blocks_rwkv7 import RWKV7GooseConfigMap
|
9 |
+
|
10 |
+
class RWKV7Config(PretrainedConfig):
|
11 |
+
"""
|
12 |
+
This is the configuration class to store the configuration of a [`Rwkv7Model`]. It is used to instantiate a RWKV7
|
13 |
+
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
14 |
+
defaults will yield a similar configuration to that of the RWKV-7
|
15 |
+
[RWKV/v7-Goose-1.6B-Pile-HF](https://huggingface.co/RWKV/v7-Goose-1.6B-Pile-HF) architecture.
|
16 |
+
|
17 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
18 |
+
documentation from [`PretrainedConfig`] for more information.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
vocab_size (`int`, *optional*, defaults to 65536):
|
22 |
+
Vocabulary size of the RWKV7 model. Defines the number of different tokens that can be represented by the
|
23 |
+
`input_ids` passed when calling [`Rwkv7Model`].
|
24 |
+
num_hidden_layers (`int`, *optional*, defaults to 24):
|
25 |
+
Number of hidden layers in the model.
|
26 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
27 |
+
Dimensionality of the embeddings and hidden states.
|
28 |
+
|
29 |
+
hidden_size_att (`int`, *optional*):
|
30 |
+
Dimensionality of the attention hidden states. Will be computed from `hidden_size` if unset.
|
31 |
+
hidden_size_ffn (`int`, *optional*):
|
32 |
+
Dimensionality of the FFN hidden states. Will be computed from `hidden_size` if unset.
|
33 |
+
head_size (`int`, *optional*, defaults to 64):
|
34 |
+
head_size of rwkv7 self_attention module.
|
35 |
+
tmix_backend (`str`, *optional*, defaults to "auto"):
|
36 |
+
Backend to use for the time mix module. "auto" defaults to "pytorch" if the device is "cpu" and "cuda" otherwise.
|
37 |
+
(Valid values: "auto", "pytorch", "cuda", "triton", "triton_bighead", "fla", "fla_fused", "pytorch_ref", "pytorch_ref_fp32")
|
38 |
+
init_state_wkv (`bool`, *optional*, defaults to `False`):
|
39 |
+
Whether to initialize the wkv state in the model. Used for WKV state tuning.
|
40 |
+
|
41 |
+
device (`str`, *optional*):
|
42 |
+
Device to use for the model. Use the respective torch.device types
|
43 |
+
dtype (`str`, *optional*):
|
44 |
+
Model weights data type. Use the respective torch.dtype types
|
45 |
+
|
46 |
+
bos_token_id (`int`, *optional*, defaults to 0):
|
47 |
+
The id of the beginning of sentence token in the vocabulary. Defaults to 0.
|
48 |
+
eos_token_id (`int`, *optional*, defaults to 0):
|
49 |
+
The id of the end of sentence token in the vocabulary. Defaults to 0.
|
50 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
51 |
+
Whether or not to tie the word embeddings with the input token embeddings.
|
52 |
+
(this value is currently ignored in our implementation)
|
53 |
+
|
54 |
+
Example:
|
55 |
+
|
56 |
+
```python
|
57 |
+
>>> from transformers import Rwkv7Config, Rwkv7Model
|
58 |
+
|
59 |
+
>>> # Initializing a Rwkv7 configuration
|
60 |
+
>>> configuration = Rwkv7Config()
|
61 |
+
|
62 |
+
>>> # Initializing a model (with random weights) from the configuration
|
63 |
+
>>> model = Rwkv7Model(configuration)
|
64 |
+
|
65 |
+
>>> # Accessing the model configuration
|
66 |
+
>>> configuration = model.config
|
67 |
+
```"""
|
68 |
+
|
69 |
+
model_type = "rwkv7"
|
70 |
+
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
########################################
|
74 |
+
# Vocab, layer count, and hidden size
|
75 |
+
vocab_size=65536,
|
76 |
+
num_hidden_layers=24,
|
77 |
+
hidden_size=768,
|
78 |
+
# Optional hidden sizes
|
79 |
+
hidden_size_att=None,
|
80 |
+
hidden_size_ffn=None,
|
81 |
+
# Headsize, timemix backend
|
82 |
+
head_size=64,
|
83 |
+
tmix_backend="auto",
|
84 |
+
init_state_wkv=False,
|
85 |
+
# Trainer model configs
|
86 |
+
dropout_rate=0.0,
|
87 |
+
# Torch device and dtype
|
88 |
+
device=None,
|
89 |
+
dtype=None,
|
90 |
+
# Tokenizer related settings in HF configuration
|
91 |
+
bos_token_id=0,
|
92 |
+
eos_token_id=0,
|
93 |
+
tie_word_embeddings=False,
|
94 |
+
########################################
|
95 |
+
**kwargs,
|
96 |
+
):
|
97 |
+
# Normalize dtype if torch_dtype is set within kwargs
|
98 |
+
if dtype is None and "torch_dtype" in kwargs:
|
99 |
+
dtype = kwargs["torch_dtype"]
|
100 |
+
|
101 |
+
self.vocab_size = vocab_size
|
102 |
+
self.num_hidden_layers = num_hidden_layers
|
103 |
+
self.hidden_size = hidden_size
|
104 |
+
self.hidden_size_att = hidden_size_att
|
105 |
+
self.hidden_size_ffn = hidden_size_ffn
|
106 |
+
|
107 |
+
self.head_size = head_size
|
108 |
+
self.tmix_backend = tmix_backend
|
109 |
+
self.init_state_wkv = init_state_wkv
|
110 |
+
|
111 |
+
self.device = device
|
112 |
+
self.dtype = dtype
|
113 |
+
|
114 |
+
self.dropout_rate = dropout_rate
|
115 |
+
|
116 |
+
# Forward to the HF PretrainedConfig
|
117 |
+
super().__init__(
|
118 |
+
tie_word_embeddings=tie_word_embeddings,
|
119 |
+
bos_token_id=bos_token_id,
|
120 |
+
eos_token_id=eos_token_id,
|
121 |
+
**kwargs
|
122 |
+
)
|
123 |
+
|
124 |
+
|
125 |
+
@staticmethod
|
126 |
+
def from_model_state_dict(state_dict: dict, **kwargs):
"""
Build an `RWKV7Config` from a model state dict.

The settings are derived by `RWKV7GooseConfigMap.from_model_state_dict(state_dict)`; the
attributes of that object are merged with `kwargs`, and `kwargs` wins on duplicate keys.
"""
|
127 |
+
goose_config = RWKV7GooseConfigMap.from_model_state_dict(state_dict)
|
128 |
+
# Join dictionary with **goose_config.__dict__ and **kwargs
|
129 |
+
return RWKV7Config(**{**goose_config.__dict__, **kwargs})
|
cuda/state_wkv7_cuda.cu
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cuda_bf16.h>
|
2 |
+
#include <assert.h>
|
3 |
+
|
4 |
+
using bf = __nv_bfloat16;
|
5 |
+
// Widen a bfloat16 value to float (thin wrapper over the __bfloat162float intrinsic).
__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
|
6 |
+
// Narrow a float to bfloat16 using __float2bfloat16_rn (the round-to-nearest variant).
__device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
|
7 |
+
|
8 |
+
typedef bf * __restrict__ F_;
|
9 |
+
|
10 |
+
// WKV7 forward pass that starts from, and writes back, a caller-supplied state.
// Layout (from the indexing below): blockIdx.x = head hh, blockIdx.y = batch bb,
// threadIdx.x = i in [0, C); the launcher in this file uses _C_ threads per block.
//   _state : float buffer indexed as [bb][hh][i][j]; copied into the thread-local `state`
//            array before the time loop and overwritten with the final state after it.
//   w_,q_,k_,v_,a_,b_ : bf16 inputs indexed as [bb][t][hh][i]; w_ is mapped to
//            exp(-exp(w_)) before use.
//   y_  : bf16 output, same indexing as the inputs.
//   s_  : float checkpoints of `state`, written every _CHUNK_LEN_ steps in transposed
//            order (element j of thread i lands at offset j*C + i inside the chunk slot).
//   sa_ : float output holding sum_j a[j]*state[j] for each (bb, t, hh, i).
// Per step: state[j] = state[j]*w[j] + sa*b[j] + k[j]*v, then y = sum_j state[j]*q[j].
// _C_ and _CHUNK_LEN_ are not defined in this file; presumably they are supplied as
// compile-time macros by the build — TODO confirm.
__global__ void forward_kernel(int T, int H, float*_state, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
|
11 |
+
constexpr int C = _C_;
|
12 |
+
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
|
13 |
+
|
14 |
+
float state[C] = {0};
|
15 |
+
int s_idx = bb*H*C*C + hh*C*C + i*C;
|
16 |
+
#pragma unroll
|
17 |
+
for (int j = 0; j < C; j++) {
|
18 |
+
state[j] = _state[s_idx+j];
|
19 |
+
}
|
20 |
+
|
21 |
+
__shared__ float q[C], k[C], w[C], a[C], b[C];
|
22 |
+
|
23 |
+
for (int t = 0; t < T; t++) {
|
24 |
+
int ind = bb*T*H*C + t*H*C + hh * C + i;
|
25 |
+
__syncthreads();
|
26 |
+
q[i] = to_float(q_[ind]);
|
27 |
+
w[i] = __expf(-__expf(to_float(w_[ind])));
|
28 |
+
k[i] = to_float(k_[ind]);
|
29 |
+
a[i] = to_float(a_[ind]);
|
30 |
+
b[i] = to_float(b_[ind]);
|
31 |
+
__syncthreads();
|
32 |
+
|
33 |
+
float sa = 0;
|
34 |
+
#pragma unroll
|
35 |
+
for (int j = 0; j < C; j++) {
|
36 |
+
sa += a[j] * state[j];
|
37 |
+
}
|
38 |
+
sa_[ind] = sa;
|
39 |
+
|
40 |
+
float v = to_float(v_[ind]);
|
41 |
+
float y = 0;
|
42 |
+
#pragma unroll
|
43 |
+
for (int j = 0; j < C; j++) {
|
44 |
+
float& s = state[j];
|
45 |
+
s = s * w[j] + sa * b[j] + k[j] * v;
|
46 |
+
y += s * q[j];
|
47 |
+
}
|
48 |
+
y_[ind] = to_bf(y);
|
49 |
+
|
50 |
+
if ((t+1)%_CHUNK_LEN_ == 0) {
|
51 |
+
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
|
52 |
+
#pragma unroll
|
53 |
+
for (int j = 0; j < C; j++) {
|
54 |
+
s_[base + j*C] = state[j];
|
55 |
+
}
|
56 |
+
}
|
57 |
+
}
|
58 |
+
|
59 |
+
#pragma unroll
|
60 |
+
for (int j = 0; j < C; j++) {
|
61 |
+
_state[s_idx+j] = state[j];
|
62 |
+
}
|
63 |
+
__syncthreads();
|
64 |
+
}
|
65 |
+
|
66 |
+
// WKV7 backward pass: walks t from T-1 down to 0 and writes dw_, dq_, dk_, dv_, da_, db_.
// Uses the same block/thread layout and [bb][t][hh][i] indexing as forward_kernel.
// The forward state is not stored per step: at each chunk boundary
// ((t+1) % _CHUNK_LEN_ == 0) stateT is reloaded from the s_ checkpoint, and between
// boundaries the previous state is recovered by inverting the forward update,
// (stateT - k*v - b*sa) / w.
// dw_ is scaled by wi * wi_fac, the derivative of exp(-exp(w_)) with respect to w_.
// NOTE(review): `_state` is accepted but never referenced in this body, so no gradient
// for the initial state is produced here — confirm that is intended.
__global__ void backward_kernel(int T, int H, float*_state, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
|
67 |
+
constexpr int C = _C_;
|
68 |
+
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
|
69 |
+
|
70 |
+
float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
|
71 |
+
__shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
|
72 |
+
float qi, wi, ki, ai, bi, dyi;
|
73 |
+
|
74 |
+
for (int t = T-1; t >= 0; t--) {
|
75 |
+
int ind = bb*T*H*C + t*H*C + hh * C + i;
|
76 |
+
__syncthreads();
|
77 |
+
q[i] = qi = to_float(q_[ind]);
|
78 |
+
float wi_fac = -__expf(to_float(w_[ind]));
|
79 |
+
w[i] = wi = __expf(wi_fac);
|
80 |
+
k[i] = ki = to_float(k_[ind]);
|
81 |
+
a[i] = ai = to_float(a_[ind]);
|
82 |
+
b[i] = bi = to_float(b_[ind]);
|
83 |
+
v[i] = to_float(v_[ind]);
|
84 |
+
dy[i] = dyi = to_float(dy_[ind]);
|
85 |
+
sa[i] = sa_[ind];
|
86 |
+
__syncthreads();
|
87 |
+
|
88 |
+
if ((t+1)%_CHUNK_LEN_ == 0) {
|
89 |
+
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
|
90 |
+
#pragma unroll
|
91 |
+
for (int j = 0; j < C; j++) {
|
92 |
+
stateT[j] = s_[base + j];
|
93 |
+
}
|
94 |
+
}
|
95 |
+
|
96 |
+
float dq = 0;
|
97 |
+
#pragma unroll
|
98 |
+
for (int j = 0; j < C; j++) {
|
99 |
+
dq += stateT[j]*dy[j];
|
100 |
+
}
|
101 |
+
dq_[ind] = to_bf(dq);
|
102 |
+
|
103 |
+
float iwi = 1.0f/wi;
|
104 |
+
#pragma unroll
|
105 |
+
for (int j = 0; j < C; j++) {
|
106 |
+
stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
|
107 |
+
dstate[j] += dyi * q[j];
|
108 |
+
dstateT[j] += qi * dy[j];
|
109 |
+
}
|
110 |
+
|
111 |
+
float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
|
112 |
+
|
113 |
+
#pragma unroll
|
114 |
+
for (int j = 0; j < C; j++) {
|
115 |
+
dw += dstateT[j]*stateT[j];
|
116 |
+
dk += dstateT[j]*v[j];
|
117 |
+
dv += dstate[j]*k[j];
|
118 |
+
dSb += dstate[j]*b[j];
|
119 |
+
db += dstateT[j]*sa[j];
|
120 |
+
}
|
121 |
+
dw_[ind] = to_bf(dw * wi * wi_fac);
|
122 |
+
dk_[ind] = to_bf(dk);
|
123 |
+
dv_[ind] = to_bf(dv);
|
124 |
+
db_[ind] = to_bf(db);
|
125 |
+
|
126 |
+
__syncthreads();
|
127 |
+
dSb_shared[i] = dSb;
|
128 |
+
__syncthreads();
|
129 |
+
|
130 |
+
float da = 0;
|
131 |
+
|
132 |
+
#pragma unroll
|
133 |
+
for (int j = 0; j < C; j++) {
|
134 |
+
da += stateT[j]*dSb_shared[j];
|
135 |
+
}
|
136 |
+
da_[ind] = to_bf(da);
|
137 |
+
|
138 |
+
#pragma unroll
|
139 |
+
for (int j = 0; j < C; j++) {
|
140 |
+
dstate[j] = dstate[j]*w[j] + dSb * a[j];
|
141 |
+
dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
|
142 |
+
}
|
143 |
+
}
|
144 |
+
}
|
145 |
+
|
146 |
+
// Host launcher for forward_kernel: grid = (H, B) blocks, _C_ threads per block.
// The host-side names z and a are passed in the kernel's a_ and b_ positions.
// The launch gives no explicit stream argument and no launch error is checked here.
void cuda_forward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
|
147 |
+
forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,_state,w,q,k,v,z,a,y,s,sa);
|
148 |
+
}
|
149 |
+
// Host launcher for backward_kernel: asserts that T is a multiple of _CHUNK_LEN_, then
// launches grid = (H, B) blocks with _C_ threads per block.
// The host-side names dz and da are passed in the kernel's da_ and db_ positions.
void cuda_backward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
|
150 |
+
assert(T%_CHUNK_LEN_ == 0);
|
151 |
+
backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,_state,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
|
152 |
+
}
|
cuda/state_wkv7_op.cpp
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
#include <cuda_bf16.h>
|
3 |
+
using bf = __nv_bfloat16;
|
4 |
+
|
5 |
+
void cuda_forward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
|
6 |
+
|
7 |
+
// Torch-facing wrapper around cuda_forward.
// B, T and H are read from the first three dimensions of w. Every tensor is handed over
// as a raw data pointer (float for _state, s, sa; bf16 for the rest) with no dtype,
// device, or contiguity checks in this function — callers must guarantee those.
void forward(torch::Tensor &_state, torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
|
8 |
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
9 |
+
cuda_forward(B, T, H, (float*)_state.data_ptr(), (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
|
10 |
+
}
|
11 |
+
|
12 |
+
void cuda_backward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
|
13 |
+
|
14 |
+
// Torch-facing wrapper around cuda_backward.
// B, T and H are read from the first three dimensions of w. Tensors are passed as raw
// data pointers (float for _state, s, sa; bf16 for inputs, dy and all gradient outputs)
// with no dtype, device, or contiguity checks in this function.
void backward(torch::Tensor &_state, torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
|
15 |
+
torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
|
16 |
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
17 |
+
cuda_backward(B, T, H, (float*)_state.data_ptr(), (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
|
18 |
+
(float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
|
19 |
+
}
|
20 |
+
|
21 |
+
// Declares the operator schemas for the state_wind_backstepping namespace.
// Tensor(x!) marks the arguments the schema declares as written in place (the outputs).
// NOTE(review): forward_kernel in state_wkv7_cuda.cu writes the final state back into
// `_state`, yet the forward schema declares `_state` as a plain Tensor — consider
// marking it mutable as well; confirm against how the op is called.
TORCH_LIBRARY(state_wind_backstepping, m) {
|
22 |
+
m.def("forward(Tensor _state, Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
|
23 |
+
m.def("backward(Tensor _state, Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
|
24 |
+
}
|
25 |
+
|
26 |
+
// Registers the forward/backward wrappers as the CUDA-dispatch implementations of the
// two operators declared in the state_wind_backstepping namespace.
TORCH_LIBRARY_IMPL(state_wind_backstepping, CUDA, m) {
|
27 |
+
m.impl("forward", &forward);
|
28 |
+
m.impl("backward", &backward);
|
29 |
+
}
|
30 |
+
|
31 |
+
// TORCH_LIBRARY(state_wind_backstepping, m) {
|
32 |
+
// m.def("forward", forward);
|
33 |
+
// m.def("backward", backward);
|
34 |
+
// }
|
cuda/wkv7_cuda.cu
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cuda_bf16.h>
|
2 |
+
#include <assert.h>
|
3 |
+
|
4 |
+
using bf = __nv_bfloat16;
|
5 |
+
// Widen a bfloat16 value to float (thin wrapper over the __bfloat162float intrinsic).
__device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
|
6 |
+
// Narrow a float to bfloat16 using __float2bfloat16_rn (the round-to-nearest variant).
__device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
|
7 |
+
|
8 |
+
typedef bf * __restrict__ F_;
|
9 |
+
|
10 |
+
// WKV7 forward pass starting from an all-zero state.
// Layout (from the indexing below): blockIdx.x = head hh, blockIdx.y = batch bb,
// threadIdx.x = i in [0, C); the launcher in this file uses _C_ threads per block.
//   w_,q_,k_,v_,a_,b_ : bf16 inputs indexed as [bb][t][hh][i]; w_ is mapped to
//            exp(-exp(w_)) before use.
//   y_  : bf16 output, same indexing as the inputs.
//   s_  : float checkpoints of the thread-local `state` array, written every
//            _CHUNK_LEN_ steps in transposed order (element j of thread i lands at
//            offset j*C + i inside the chunk slot).
//   sa_ : float output holding sum_j a[j]*state[j] for each (bb, t, hh, i).
// Per step: state[j] = state[j]*w[j] + sa*b[j] + k[j]*v, then y = sum_j state[j]*q[j].
// _C_ and _CHUNK_LEN_ are not defined in this file; presumably they are supplied as
// compile-time macros by the build — TODO confirm.
__global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
|
11 |
+
constexpr int C = _C_;
|
12 |
+
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
|
13 |
+
|
14 |
+
float state[C] = {0};
|
15 |
+
__shared__ float q[C], k[C], w[C], a[C], b[C];
|
16 |
+
|
17 |
+
for (int t = 0; t < T; t++) {
|
18 |
+
int ind = bb*T*H*C + t*H*C + hh * C + i;
|
19 |
+
__syncthreads();
|
20 |
+
q[i] = to_float(q_[ind]);
|
21 |
+
w[i] = __expf(-__expf(to_float(w_[ind])));
|
22 |
+
k[i] = to_float(k_[ind]);
|
23 |
+
a[i] = to_float(a_[ind]);
|
24 |
+
b[i] = to_float(b_[ind]);
|
25 |
+
__syncthreads();
|
26 |
+
|
27 |
+
float sa = 0;
|
28 |
+
#pragma unroll
|
29 |
+
for (int j = 0; j < C; j++) {
|
30 |
+
sa += a[j] * state[j];
|
31 |
+
}
|
32 |
+
sa_[ind] = sa;
|
33 |
+
|
34 |
+
float v = to_float(v_[ind]);
|
35 |
+
float y = 0;
|
36 |
+
#pragma unroll
|
37 |
+
for (int j = 0; j < C; j++) {
|
38 |
+
float& s = state[j];
|
39 |
+
s = s * w[j] + sa * b[j] + k[j] * v;
|
40 |
+
y += s * q[j];
|
41 |
+
}
|
42 |
+
y_[ind] = to_bf(y);
|
43 |
+
|
44 |
+
if ((t+1)%_CHUNK_LEN_ == 0) {
|
45 |
+
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
|
46 |
+
#pragma unroll
|
47 |
+
for (int j = 0; j < C; j++) {
|
48 |
+
s_[base + j*C] = state[j];
|
49 |
+
}
|
50 |
+
}
|
51 |
+
}
|
52 |
+
}
|
53 |
+
|
54 |
+
// WKV7 backward pass: walks t from T-1 down to 0 and writes dw_, dq_, dk_, dv_, da_, db_.
// Uses the same block/thread layout and [bb][t][hh][i] indexing as forward_kernel.
// The forward state is not stored per step: at each chunk boundary
// ((t+1) % _CHUNK_LEN_ == 0) stateT is reloaded from the s_ checkpoint, and between
// boundaries the previous state is recovered by inverting the forward update,
// (stateT - k*v - b*sa) / w.
// dw_ is scaled by wi * wi_fac, the derivative of exp(-exp(w_)) with respect to w_.
__global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
|
55 |
+
constexpr int C = _C_;
|
56 |
+
int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
|
57 |
+
|
58 |
+
float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
|
59 |
+
__shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
|
60 |
+
float qi, wi, ki, ai, bi, dyi;
|
61 |
+
|
62 |
+
for (int t = T-1; t >= 0; t--) {
|
63 |
+
int ind = bb*T*H*C + t*H*C + hh * C + i;
|
64 |
+
__syncthreads();
|
65 |
+
q[i] = qi = to_float(q_[ind]);
|
66 |
+
float wi_fac = -__expf(to_float(w_[ind]));
|
67 |
+
w[i] = wi = __expf(wi_fac);
|
68 |
+
k[i] = ki = to_float(k_[ind]);
|
69 |
+
a[i] = ai = to_float(a_[ind]);
|
70 |
+
b[i] = bi = to_float(b_[ind]);
|
71 |
+
v[i] = to_float(v_[ind]);
|
72 |
+
dy[i] = dyi = to_float(dy_[ind]);
|
73 |
+
sa[i] = sa_[ind];
|
74 |
+
__syncthreads();
|
75 |
+
|
76 |
+
if ((t+1)%_CHUNK_LEN_ == 0) {
|
77 |
+
int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
|
78 |
+
#pragma unroll
|
79 |
+
for (int j = 0; j < C; j++) {
|
80 |
+
stateT[j] = s_[base + j];
|
81 |
+
}
|
82 |
+
}
|
83 |
+
|
84 |
+
float dq = 0;
|
85 |
+
#pragma unroll
|
86 |
+
for (int j = 0; j < C; j++) {
|
87 |
+
dq += stateT[j]*dy[j];
|
88 |
+
}
|
89 |
+
dq_[ind] = to_bf(dq);
|
90 |
+
|
91 |
+
float iwi = 1.0f/wi;
|
92 |
+
#pragma unroll
|
93 |
+
for (int j = 0; j < C; j++) {
|
94 |
+
stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
|
95 |
+
dstate[j] += dyi * q[j];
|
96 |
+
dstateT[j] += qi * dy[j];
|
97 |
+
}
|
98 |
+
|
99 |
+
float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
|
100 |
+
#pragma unroll
|
101 |
+
for (int j = 0; j < C; j++) {
|
102 |
+
dw += dstateT[j]*stateT[j];
|
103 |
+
dk += dstateT[j]*v[j];
|
104 |
+
dv += dstate[j]*k[j];
|
105 |
+
dSb += dstate[j]*b[j];
|
106 |
+
db += dstateT[j]*sa[j];
|
107 |
+
}
|
108 |
+
dw_[ind] = to_bf(dw * wi * wi_fac);
|
109 |
+
dk_[ind] = to_bf(dk);
|
110 |
+
dv_[ind] = to_bf(dv);
|
111 |
+
db_[ind] = to_bf(db);
|
112 |
+
|
113 |
+
__syncthreads();
|
114 |
+
dSb_shared[i] = dSb;
|
115 |
+
__syncthreads();
|
116 |
+
|
117 |
+
float da = 0;
|
118 |
+
#pragma unroll
|
119 |
+
for (int j = 0; j < C; j++) {
|
120 |
+
da += stateT[j]*dSb_shared[j];
|
121 |
+
}
|
122 |
+
da_[ind] = to_bf(da);
|
123 |
+
|
124 |
+
#pragma unroll
|
125 |
+
for (int j = 0; j < C; j++) {
|
126 |
+
dstate[j] = dstate[j]*w[j] + dSb * a[j];
|
127 |
+
dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
|
128 |
+
}
|
129 |
+
}
|
130 |
+
}
|
131 |
+
|
132 |
+
// Host launcher for forward_kernel: grid = (H, B) blocks, _C_ threads per block.
// The host-side names z and a are passed in the kernel's a_ and b_ positions.
// The launch gives no explicit stream argument and no launch error is checked here.
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
|
133 |
+
forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
|
134 |
+
}
|
135 |
+
// Host launcher for backward_kernel: asserts that T is a multiple of _CHUNK_LEN_, then
// launches grid = (H, B) blocks with _C_ threads per block.
// The host-side names dz and da are passed in the kernel's da_ and db_ positions.
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
|
136 |
+
assert(T%_CHUNK_LEN_ == 0);
|
137 |
+
backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
|
138 |
+
}
|
cuda/wkv7_op.cpp
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
#include <cuda_bf16.h>
|
3 |
+
using bf = __nv_bfloat16;
|
4 |
+
|
5 |
+
void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
|
6 |
+
|
7 |
+
// Torch-facing wrapper around cuda_forward.
// B, T and H are read from the first three dimensions of w. Every tensor is handed over
// as a raw data pointer (float for s and sa; bf16 for the rest) with no dtype, device,
// or contiguity checks in this function — callers must guarantee those.
void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
|
8 |
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
9 |
+
cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
|
10 |
+
}
|
11 |
+
|
12 |
+
void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
|
13 |
+
|
14 |
+
// Torch-facing wrapper around cuda_backward.
// B, T and H are read from the first three dimensions of w. Tensors are passed as raw
// data pointers (float for s and sa; bf16 for inputs, dy and all gradient outputs) with
// no dtype, device, or contiguity checks in this function.
void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
|
15 |
+
torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
|
16 |
+
int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
|
17 |
+
cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
|
18 |
+
(float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
|
19 |
+
}
|
20 |
+
|
21 |
+
// Declares the operator schemas for the wind_backstepping namespace.
// Tensor(x!) marks the arguments the schema declares as written in place (the outputs).
TORCH_LIBRARY(wind_backstepping, m) {
|
22 |
+
m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
|
23 |
+
m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
|
24 |
+
}
|
25 |
+
|
26 |
+
// Registers the forward/backward wrappers as the CUDA-dispatch implementations of the
// two operators declared in the wind_backstepping namespace.
TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
|
27 |
+
m.impl("forward", &forward);
|
28 |
+
m.impl("backward", &backward);
|
29 |
+
}
|
30 |
+
|
31 |
+
// TORCH_LIBRARY(wind_backstepping, m) {
|
32 |
+
// m.def("forward", forward);
|
33 |
+
// m.def("backward", backward);
|
34 |
+
// }
|
modeling_blocks_rwkv7.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
modeling_rwkv7.py
ADDED
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" RWKV Modeling"""
|
2 |
+
|
3 |
+
from transformers.modeling_utils import PreTrainedModel
|
4 |
+
from transformers.utils import (
|
5 |
+
ModelOutput,
|
6 |
+
add_code_sample_docstrings,
|
7 |
+
add_start_docstrings,
|
8 |
+
add_start_docstrings_to_model_forward,
|
9 |
+
is_ninja_available,
|
10 |
+
is_torch_cuda_available,
|
11 |
+
logging,
|
12 |
+
)
|
13 |
+
from transformers.generation import GenerationMixin
|
14 |
+
from transformers.modeling_outputs import ModelOutput
|
15 |
+
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
from torch.nn import CrossEntropyLoss
|
19 |
+
import torch.nn.functional as F
|
20 |
+
|
21 |
+
import warnings
|
22 |
+
from dataclasses import dataclass
|
23 |
+
from typing import List, Dict, Optional, Tuple, Union, Any
|
24 |
+
|
25 |
+
# Load the RWKV7Config and RWKV7GooseModel
|
26 |
+
from .configuration_rwkv7 import RWKV7Config
|
27 |
+
from .modeling_blocks_rwkv7 import RWKV7GooseModel
|
28 |
+
|
29 |
+
class RWKV7PreTrainedModel(PreTrainedModel,RWKV7GooseModel):
|
30 |
+
"""
|
31 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
|
32 |
+
"""
|
33 |
+
config_class = RWKV7Config
|
34 |
+
base_model_prefix = "rwkv7"
|
35 |
+
is_parallelizable = True
|
36 |
+
_no_split_modules = ["RWKV7LayerBlock"]
|
37 |
+
_keep_in_fp32_modules = []
|
38 |
+
supports_gradient_checkpointing = True
|
39 |
+
|
40 |
+
# Only RWKV7GooseModel.__init__ is called explicitly, with the config's attribute dict;
# the config object itself is then stored on the instance.
# NOTE(review): PreTrainedModel.__init__ is not invoked here — confirm that the
# RWKV7GooseModel initializer chains to it, or that skipping it is intended.
def __init__(self, config: RWKV7Config):
|
41 |
+
RWKV7GooseModel.__init__(self, config.__dict__)
|
42 |
+
self.config = config
|
43 |
+
|
44 |
+
# Initialize the weights of `module`.
# Modules exposing reset_parameters() or init_parameters() initialize themselves and the
# method returns early; containers (ParameterList, ModuleList, ParameterDict) are walked
# recursively; the remaining handled types draw weights from normal(0, 0.02), with
# Linear/Conv1d biases zeroed.
# NOTE(review): stock nn.Linear / nn.Conv1d / nn.LayerNorm / nn.Embedding all define
# reset_parameters(), so they presumably take the early-return path and never reach the
# type-specific branches below — confirm that is intended.
def _init_weights(
|
45 |
+
self,
|
46 |
+
module
|
47 |
+
):
|
48 |
+
# Fallback to the default init weights
|
49 |
+
if hasattr(module, 'reset_parameters'):
|
50 |
+
module.reset_parameters()
|
51 |
+
return
|
52 |
+
elif hasattr(module, 'init_parameters'):
|
53 |
+
module.init_parameters()
|
54 |
+
return
|
55 |
+
|
56 |
+
# Default FP initializer_range for Linear / LN layers
|
57 |
+
initializer_range = 0.02
|
58 |
+
|
59 |
+
if isinstance(module, (nn.ParameterList, nn.ModuleList)):
|
60 |
+
# Iterate and initialize each parameter
|
61 |
+
for param in module:
|
62 |
+
self._init_weights(param)
|
63 |
+
elif isinstance(module, nn.ParameterDict):
|
64 |
+
# Iterate and initialize each parameter
|
65 |
+
for key, param in module.items():
|
66 |
+
self._init_weights(param)
|
67 |
+
|
68 |
+
elif isinstance(module, (nn.Linear, nn.Conv1d)):
|
69 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
70 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
71 |
+
nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
|
72 |
+
if module.bias is not None:
|
73 |
+
nn.init.zeros_(module.bias)
|
74 |
+
elif isinstance(module, nn.LayerNorm):
|
75 |
+
nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
|
76 |
+
elif isinstance(module, nn.Parameter):
|
77 |
+
nn.init.normal_(module, mean=0.0, std=initializer_range)
|
78 |
+
elif isinstance(module, nn.Embedding):
|
79 |
+
nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
|
80 |
+
|
81 |
+
# # RWKV does not use a blank pad idx. The pad_idx is a training token
|
82 |
+
# if module.padding_idx is not None:
|
83 |
+
# module.weight.data[module.padding_idx].zero_()
|
84 |
+
|
85 |
+
@dataclass
|
86 |
+
class RWKV7Output(ModelOutput):
|
87 |
+
"""
|
88 |
+
Class for the RWKV model outputs.
|
89 |
+
Args:
|
90 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
91 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
92 |
+
rwkv_state (list of 3-tuples of `torch.Tensor`, *optional*):
|
93 |
+
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
|
94 |
+
avoid providing the old `input_ids`.
|
95 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
96 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
97 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
|
98 |
+
the model at the output of each layer plus the optional initial embedding outputs.
|
99 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
100 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
101 |
+
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
|
102 |
+
the self-attention heads.
|
103 |
+
"""
|
104 |
+
last_hidden_state: torch.FloatTensor = None
|
105 |
+
rwkv_state: Optional[list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]] = None
|
106 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
107 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
108 |
+
|
109 |
+
|
110 |
+
@dataclass
class RWKV7CausalLMOutput(ModelOutput):
    """
    Output container returned by `RWKV7ForCausalLM.forward` (causal / autoregressive language modeling).

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        rwkv_state (`list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]`, *optional*):
            The recurrent state of the model: a list indexed by block, each entry being a tuple of three tensors.
            Can be passed as `rwkv_state` to a later forward call together with the next `input_ids` to avoid
            providing the old `input_ids` again. The meaning and shapes of the three tensors are defined by the
            block implementation and should be confirmed there.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
            the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """
    # NOTE: field order is part of the `ModelOutput` contract (tuple-style indexing); do not reorder.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    rwkv_state: Optional[list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
|
136 |
+
|
137 |
+
# Shared class-level docstring fragment, attached to the model classes through `add_start_docstrings`.
RWKV7_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
    general usage and behavior.
    Parameters:
        config ([`RWKV7Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
|
148 |
+
|
149 |
+
# Shared `forward` argument documentation, attached through `add_start_docstrings_to_model_forward`.
# Parameter names here must match the `forward` signatures (`rwkv_state`, not `state` / `past_key_values`).
RWKV7_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary. If `rwkv_state` is passed, only the `input_ids` that
            have not already been processed into that state should be passed as `input_ids`. Indices can be obtained
            using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for
            details. [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask marking padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            The base model forward does not use this mask; the causal LM forward only uses it to weight the
            per-token loss when `labels` are provided.
            [What are attention masks?](../glossary#attention-mask)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix. Exactly one of `input_ids` and `inputs_embeds` must be given.

        rwkv_state (`list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]`, *optional*):
            List of block states, one entry per block, representing the RWKV internal states. If passed along, the
            model uses the previous state in all the blocks (which will give the output for the `input_ids` provided
            as if the model had received `state_input_ids + input_ids` as context). A passed state is replaced by a
            freshly initialized one when `use_cache` resolves to `False`.

        use_cache (`bool`, *optional*):
            If set to `True`, the last state is returned and can be used to quickly generate the next logits.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
184 |
+
|
185 |
+
@add_start_docstrings(
    "The bare RWKV7 Model transformer outputting raw hidden-states without activating the head (variable is still declared)",
    RWKV7_START_DOCSTRING,
)
class RWKV7Model(RWKV7PreTrainedModel):
    # Bare RWKV7 backbone: embeds the input, runs it through every block while threading the
    # per-block recurrent state, applies the final layer norm and returns the hidden states
    # together with the updated state. The submodules used here (`emb`, `blocks`, `ln_out`,
    # `head`) and `get_init_state` are provided by the parent class.

    def __init__(self, config: RWKV7Config):
        super().__init__(config)

    def get_input_embeddings(self):
        return self.emb

    def set_input_embeddings(self, value):
        self.emb = value

    # The output projection is `self.head`: that is the module `RWKV7ForCausalLM.forward` applies
    # to produce logits. These accessors previously pointed at `self.lm_head`, which nothing in
    # this file assigns or uses.
    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings

    @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        output_type=RWKV7Output,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # not in use
        inputs_embeds: Optional[torch.FloatTensor] = None,
        rwkv_state: Optional[list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, RWKV7Output]:
        # The argument documentation is injected by the decorators above, so no docstring is
        # written here (it would be concatenated into the generated one).
        #
        # Returns an `RWKV7Output`, or, when `return_dict` resolves to False, a tuple of the
        # non-None values among (last hidden state, updated state, hidden states, attentions).
        # Raises ValueError unless exactly one of `input_ids` / `inputs_embeds` is given.

        # The stdlib `warnings` module exposes `warn`, not `warning_once`; import it locally so
        # the warning paths below cannot fail with an AttributeError.
        import warnings

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if output_attentions:
            warnings.warn("`RWKV7Model` does not `output_attentions` now, setting it to `False`.")
            output_attentions = False

        if self.gradient_checkpointing and self.training and use_cache:
            warnings.warn("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

        if output_hidden_states:
            warnings.warn("`RWKV7Model` does not `output_hidden_states` now, setting it to `False`.")
            output_hidden_states = False

        # ---

        # Compute the input embeddings
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if inputs_embeds is None:
            inputs_embeds = self.emb(input_ids.to(self.emb.weight.device))
        x_hidden_state = inputs_embeds
        batch_size = x_hidden_state.shape[0]

        # Initialize the incoming state list. An explicitly disabled cache discards any
        # caller-supplied state; an unresolved (`None`) `use_cache` does not.
        cache_disabled = use_cache is not None and not use_cache
        if rwkv_state is None or cache_disabled:
            rwkv_state = self.get_init_state(batch_size=batch_size)
        prv_stateList = rwkv_state

        # Container that receives the state produced by each block
        ret_stateList = self.get_init_state(batch_size=batch_size, skip_init_state=True)

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        v_first = None  # threaded from block to block; its content is produced by the blocks

        # Lets start iterating the blocks
        for i, block in enumerate(self.blocks):
            # Build the full inner hidden state
            if output_hidden_states:
                all_hidden_states += (x_hidden_state,)

            # Forward the block (through the checkpointing wrapper while training with it enabled)
            if self.gradient_checkpointing and self.training:
                x_hidden_state, ret_sublist, v_first = self._gradient_checkpointing_func(
                    block.__call__, x_hidden_state, prv_stateList[i], v_first
                )
            else:
                x_hidden_state, ret_sublist, v_first = block(x_hidden_state, prv_stateList[i], v_first)
            ret_stateList[i] = ret_sublist

        # Final layer norm
        x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True)
        x_hidden_state = self.ln_out(x_hidden_state)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (x_hidden_state,)

        # Hand back the state produced by the blocks (`ret_stateList`), not the state that was
        # fed in: returning the input would give callers a stale state to continue from.
        if not return_dict:
            return tuple(
                item
                for item in [x_hidden_state, ret_stateList, all_hidden_states, all_attns]
                if item is not None
            )
        return RWKV7Output(
            last_hidden_state=x_hidden_state,
            rwkv_state=ret_stateList,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
|
300 |
+
|
301 |
+
@add_start_docstrings(
    """
    The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    RWKV7_START_DOCSTRING,
)
class RWKV7ForCausalLM(RWKV7Model, GenerationMixin):

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        rwkv_state: Optional[list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]] = None,
        # num_new_tokens_if_rwkv_state: int = 1, # Only triggers if given input_ids + rwkv_state
        num_logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        '''
        Assemble the keyword arguments for a `forward` call during generation.

        Exactly one of `input_ids` / `inputs_embeds` is forwarded, together with `rwkv_state`,
        `use_cache`, `attention_mask` and `num_logits_to_keep` (the latter is always present in
        the result, possibly as `None`).

        NOTE(review): presumably invoked by `GenerationMixin` before every decoding step; confirm
        against the installed transformers version.

        Raises:
            ValueError: if both `input_ids` and `inputs_embeds` are given.
        '''
        # # only last token for `inputs_ids` if the `past_key_values` is passed along.
        # if rwkv_state is not None and input_ids is not None:
        #     input_ids = input_ids[:, -num_new_tokens_if_rwkv_state:]

        if inputs_embeds is not None:
            # Embeddings and token ids are mutually exclusive inputs.
            if input_ids is not None:
                raise ValueError("You cannot specify both `inputs_ids` and `inputs_embeds` at the same time")
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        model_inputs.update({
            'rwkv_state': rwkv_state,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'num_logits_to_keep': num_logits_to_keep,
        })
        return model_inputs

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        num_new_tokens: int = 1,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Update `model_kwargs` after a generation step and return it.

        When `outputs` carries an `rwkv_state` and `model_kwargs` holds `input_ids`, the ids are
        trimmed to the last `num_new_tokens`; an `attention_mask` in `model_kwargs` is replaced by
        an all-ones mask of shape `(batch_size, num_new_tokens)`.

        NOTE(review): the state read from `outputs` is only used as a condition here and is not
        written back into `model_kwargs`, so this method by itself does not carry the state over
        to the next step. Confirm whether that is intended before changing it, since
        `prepare_inputs_for_generation` forwards the full `input_ids` unchanged.
        """
        # Overwritten -- this model uses `state`, but doesn't have a cache (`past_key_values`)
        rwkv_state = outputs.get("rwkv_state", None)
        input_ids = model_kwargs.get("input_ids", None)
        attention_mask = model_kwargs.get("attention_mask", None)

        # only last token for inputs_ids if the state is passed along.
        if rwkv_state is not None and input_ids is not None and num_new_tokens > 0:
            model_kwargs["input_ids"] = input_ids[:, -num_new_tokens:]

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))

        # Return the formated output
        return model_kwargs

    @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        output_type=RWKV7CausalLMOutput,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,  # noqa
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        rwkv_state: Optional[list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, RWKV7CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the backbone and read its result by field name. Positional indexing into the tuple
        # form is fragile because the backbone drops `None` entries from that tuple, which made
        # `rwkv_outputs[2]` / `[3]` raise IndexError whenever hidden states or attentions were
        # requested.
        rwkv_outputs = RWKV7Model.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            rwkv_state=rwkv_state,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # Get the hidden state, and the updated RWKV state
        hidden_states = rwkv_outputs.last_hidden_state
        rwkv_state = rwkv_outputs.rwkv_state

        # Get the ALL hidden states and attentions dumps (None when the backbone did not produce them)
        all_hidden_states = rwkv_outputs.hidden_states
        all_attns = rwkv_outputs.attentions

        # Forward the head state
        logits = self.head(hidden_states)

        # Compute the loss from the labels
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten to (tokens, vocab) / (tokens,). The class dimension must come from the
            # logits; taking it from the labels (sequence length) is wrong.
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1)

            if attention_mask is not None:
                # Per-token loss weighted by the shifted mask, averaged over the unmasked tokens.
                token_loss = F.cross_entropy(flat_logits, flat_labels, reduction="none")
                submask = attention_mask[..., 1:].contiguous().view(-1).to(token_loss.device)
                # clamp: an all-zero mask yields a zero loss instead of a 0/0 NaN
                loss = (token_loss * submask).sum() / submask.sum().clamp(min=1)
            else:
                loss = F.cross_entropy(flat_logits, flat_labels, reduction="mean")

        if not return_dict:
            return tuple(
                item
                for item in [loss, logits, rwkv_state, all_hidden_states, all_attns]
                if item is not None
            )

        return RWKV7CausalLMOutput(
            loss=loss,
            logits=logits,
            rwkv_state=rwkv_state,
            hidden_states=all_hidden_states,
            attentions=all_attns,
        )
|
459 |
+
|
460 |
+
# Public names exported by this module on `import *`.
__all__ = ["RWKV7ForCausalLM", "RWKV7Model", "RWKV7PreTrainedModel"]
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "tokenizer_class": "GPTNeoXTokenizer"}
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|