picocreator committed
Commit 33b8599 · verified · 1 Parent(s): fb27943

Upload 13 files

__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "architectures": [
+     "RWKV7ForCausalLM",
+     "RWKV7Model",
+     "RWKV7PreTrainedModel"
+   ],
+   "bos_token_id": 0,
+   "device": null,
+   "dropout_rate": 0.0,
+   "dtype": null,
+   "eos_token_id": 0,
+   "head_size": 64,
+   "hidden_size": 2048,
+   "hidden_size_att": 2048,
+   "hidden_size_ffn": 8192,
+   "init_state_wkv": false,
+   "layer_id": null,
+   "model_type": "rwkv7",
+   "num_hidden_layers": 24,
+   "tie_word_embeddings": false,
+   "tmix_backend": "auto",
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.48.0",
+   "vocab_size": 50304,
+   "auto_map": {
+     "AutoConfig": "configuration_rwkv7.RWKV7Config",
+     "AutoModel": "modeling_rwkv7.RWKV7Model",
+     "AutoModelForCausalLM": "modeling_rwkv7.RWKV7ForCausalLM"
+   }
+ }
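The `auto_map` block above is what lets the stock `Auto*` loaders resolve these custom classes, so the checkpoint is meant to be loaded with `trust_remote_code=True`. A minimal loading sketch; the repo id is a placeholder, and the dtype simply mirrors the `torch_dtype` entry above:

```python
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "picocreator/rwkv7-goose-upload"  # placeholder repo id, not the real one

# trust_remote_code is required: the config/model classes live inside this repo
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype="bfloat16",  # matches "torch_dtype" in config.json
)
print(config.num_hidden_layers, config.hidden_size)  # 24, 2048
```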
configuration_rwkv7.py ADDED
@@ -0,0 +1,129 @@
+ """ RWKV configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ # from transformers.utils import logging
+ # logger = logging.get_logger(__name__)
+
+ # Import the dependencies
+ from .modeling_blocks_rwkv7 import RWKV7GooseConfigMap
+
+ class RWKV7Config(PretrainedConfig):
+     """
+     This is the configuration class to store the configuration of a [`Rwkv7Model`]. It is used to instantiate a RWKV7
+     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
+     defaults will yield a configuration similar to that of the RWKV-7
+     [RWKV/v7-Goose-1.6B-Pile-HF](https://huggingface.co/RWKV/v7-Goose-1.6B-Pile-HF) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 65536):
+             Vocabulary size of the RWKV7 model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`Rwkv7Model`].
+         num_hidden_layers (`int`, *optional*, defaults to 24):
+             Number of hidden layers in the model.
+         hidden_size (`int`, *optional*, defaults to 768):
+             Dimensionality of the embeddings and hidden states.
+         hidden_size_att (`int`, *optional*):
+             Dimensionality of the attention hidden states. Will be computed from `hidden_size` if unset.
+         hidden_size_ffn (`int`, *optional*):
+             Dimensionality of the FFN hidden states. Will be computed from `hidden_size` if unset.
+         head_size (`int`, *optional*, defaults to 64):
+             Head size of the RWKV7 self-attention (time mix) module.
+         tmix_backend (`str`, *optional*, defaults to `"auto"`):
+             Backend to use for the time mix module. "auto" defaults to "pytorch" if the device is "cpu" and "cuda" otherwise.
+             (Valid values: "auto", "pytorch", "cuda", "triton", "triton_bighead", "fla", "fla_fused", "pytorch_ref", "pytorch_ref_fp32")
+         init_state_wkv (`bool`, *optional*, defaults to `False`):
+             Whether to initialize the wkv state in the model. Used for WKV state tuning.
+         device (`str`, *optional*):
+             Device to use for the model. Use the respective torch.device types.
+         dtype (`str`, *optional*):
+             Model weights data type. Use the respective torch.dtype types.
+         bos_token_id (`int`, *optional*, defaults to 0):
+             The id of the beginning of sentence token in the vocabulary.
+         eos_token_id (`int`, *optional*, defaults to 0):
+             The id of the end of sentence token in the vocabulary.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether or not to tie the word embeddings with the input token embeddings.
+             (This value is currently ignored in our implementation.)
+
+     Example:
+
+     ```python
+     >>> from transformers import Rwkv7Config, Rwkv7Model
+
+     >>> # Initializing a Rwkv7 configuration
+     >>> configuration = Rwkv7Config()
+
+     >>> # Initializing a model (with random weights) from the configuration
+     >>> model = Rwkv7Model(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```"""
+
+     model_type = "rwkv7"
+
+     def __init__(
+         self,
+         ########################################
+         # Vocab, layer count, and hidden size
+         vocab_size=65536,
+         num_hidden_layers=24,
+         hidden_size=768,
+         # Optional hidden sizes
+         hidden_size_att=None,
+         hidden_size_ffn=None,
+         # Head size, timemix backend
+         head_size=64,
+         tmix_backend="auto",
+         init_state_wkv=False,
+         # Trainer model configs
+         dropout_rate=0.0,
+         # Torch device and dtype
+         device=None,
+         dtype=None,
+         # Tokenizer related settings in HF configuration
+         bos_token_id=0,
+         eos_token_id=0,
+         tie_word_embeddings=False,
+         ########################################
+         **kwargs,
+     ):
+         # Normalize dtype if torch_dtype is set within kwargs
+         if dtype is None and "torch_dtype" in kwargs:
+             dtype = kwargs["torch_dtype"]
+
+         self.vocab_size = vocab_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hidden_size = hidden_size
+         self.hidden_size_att = hidden_size_att
+         self.hidden_size_ffn = hidden_size_ffn
+
+         self.head_size = head_size
+         self.tmix_backend = tmix_backend
+         self.init_state_wkv = init_state_wkv
+
+         self.device = device
+         self.dtype = dtype
+
+         self.dropout_rate = dropout_rate
+
+         # Forward to the HF PretrainedConfig
+         super().__init__(
+             tie_word_embeddings=tie_word_embeddings,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             **kwargs
+         )
+
+     @staticmethod
+     def from_model_state_dict(state_dict: dict, **kwargs):
+         goose_config = RWKV7GooseConfigMap.from_model_state_dict(state_dict)
+         # Join dictionary with **goose_config.__dict__ and **kwargs
+         return RWKV7Config(**{**goose_config.__dict__, **kwargs})
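The `from_model_state_dict` helper above derives a config directly from a raw RWKV7 checkpoint's weights via `RWKV7GooseConfigMap`. A hedged usage sketch; the checkpoint filename and the override kwarg are illustrative:

```python
import torch
from configuration_rwkv7 import RWKV7Config  # module name as uploaded in this commit

# Hypothetical path to a raw RWKV7 ".pth" state dict
state_dict = torch.load("rwkv7-goose.pth", map_location="cpu", weights_only=True)

# Layer count and hidden sizes are inferred from the weights; extra kwargs override them
config = RWKV7Config.from_model_state_dict(state_dict, tmix_backend="pytorch")
print(config.num_hidden_layers, config.hidden_size)
```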
cuda/state_wkv7_cuda.cu ADDED
@@ -0,0 +1,152 @@
+ #include <cuda_bf16.h>
+ #include <assert.h>
+
+ using bf = __nv_bfloat16;
+ __device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
+ __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
+
+ typedef bf * __restrict__ F_;
+
+ __global__ void forward_kernel(int T, int H, float*_state, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+     float state[C] = {0};
+     int s_idx = bb*H*C*C + hh*C*C + i*C;
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         state[j] = _state[s_idx+j];
+     }
+
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     for (int t = 0; t < T; t++) {
+         int ind = bb*T*H*C + t*H*C + hh * C + i;
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             sa += a[j] * state[j];
+         }
+         sa_[ind] = sa;
+
+         float v = to_float(v_[ind]);
+         float y = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             float& s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
+             #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 s_[base + j*C] = state[j];
+             }
+         }
+     }
+
+     #pragma unroll
+     for (int j = 0; j < C; j++) {
+         _state[s_idx+j] = state[j];
+     }
+     __syncthreads();
+ }
+
+ __global__ void backward_kernel(int T, int H, float*_state, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+     float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+     __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+     float qi, wi, ki, ai, bi, dyi;
+
+     for (int t = T-1; t >= 0; t--) {
+         int ind = bb*T*H*C + t*H*C + hh * C + i;
+         __syncthreads();
+         q[i] = qi = to_float(q_[ind]);
+         float wi_fac = -__expf(to_float(w_[ind]));
+         w[i] = wi = __expf(wi_fac);
+         k[i] = ki = to_float(k_[ind]);
+         a[i] = ai = to_float(a_[ind]);
+         b[i] = bi = to_float(b_[ind]);
+         v[i] = to_float(v_[ind]);
+         dy[i] = dyi = to_float(dy_[ind]);
+         sa[i] = sa_[ind];
+         __syncthreads();
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
+             #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 stateT[j] = s_[base + j];
+             }
+         }
+
+         float dq = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dq += stateT[j]*dy[j];
+         }
+         dq_[ind] = to_bf(dq);
+
+         float iwi = 1.0f/wi;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+             dstate[j] += dyi * q[j];
+             dstateT[j] += qi * dy[j];
+         }
+
+         float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dw += dstateT[j]*stateT[j];
+             dk += dstateT[j]*v[j];
+             dv += dstate[j]*k[j];
+             dSb += dstate[j]*b[j];
+             db += dstateT[j]*sa[j];
+         }
+         dw_[ind] = to_bf(dw * wi * wi_fac);
+         dk_[ind] = to_bf(dk);
+         dv_[ind] = to_bf(dv);
+         db_[ind] = to_bf(db);
+
+         __syncthreads();
+         dSb_shared[i] = dSb;
+         __syncthreads();
+
+         float da = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             da += stateT[j]*dSb_shared[j];
+         }
+         da_[ind] = to_bf(da);
+
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dstate[j] = dstate[j]*w[j] + dSb * a[j];
+             dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
+         }
+     }
+ }
+
+ void cuda_forward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
+     forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,_state,w,q,k,v,z,a,y,s,sa);
+ }
+ void cuda_backward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
+     assert(T%_CHUNK_LEN_ == 0);
+     backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,_state,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
+ }
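The kernel depends on two compile-time macros, `_C_` (the head size) and `_CHUNK_LEN_` (how often the state is checkpointed for the backward pass), so it must be JIT-built with matching `-D` flags. A build sketch using `torch.utils.cpp_extension.load`; the concrete values 64 and 16 are assumptions for illustration, not read from this commit:

```python
from torch.utils.cpp_extension import load

HEAD_SIZE = 64   # must match config.head_size (_C_)
CHUNK_LEN = 16   # assumed state-checkpoint interval (_CHUNK_LEN_)

# Builds cuda/state_wkv7_op.cpp + cuda/state_wkv7_cuda.cu and registers the
# torch.ops.state_wind_backstepping.{forward,backward} operators.
load(
    name="state_wind_backstepping",
    sources=["cuda/state_wkv7_op.cpp", "cuda/state_wkv7_cuda.cu"],
    is_python_module=False,  # ops are exposed through torch.ops, not as a Python module
    verbose=True,
    extra_cuda_cflags=["-O3", f"-D_C_={HEAD_SIZE}", f"-D_CHUNK_LEN_={CHUNK_LEN}"],
)
```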
cuda/state_wkv7_op.cpp ADDED
@@ -0,0 +1,34 @@
+ #include <torch/extension.h>
+ #include <cuda_bf16.h>
+ using bf = __nv_bfloat16;
+
+ void cuda_forward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
+
+ void forward(torch::Tensor &_state, torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_forward(B, T, H, (float*)_state.data_ptr(), (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
+ }
+
+ void cuda_backward(int B, int T, int H, float*_state, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
+
+ void backward(torch::Tensor &_state, torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
+               torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_backward(B, T, H, (float*)_state.data_ptr(), (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
+                   (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
+ }
+
+ TORCH_LIBRARY(state_wind_backstepping, m) {
+     m.def("forward(Tensor _state, Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
+     m.def("backward(Tensor _state, Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
+ }
+
+ TORCH_LIBRARY_IMPL(state_wind_backstepping, CUDA, m) {
+     m.impl("forward", &forward);
+     m.impl("backward", &backward);
+ }
+
+ // TORCH_LIBRARY(state_wind_backstepping, m) {
+ //     m.def("forward", forward);
+ //     m.def("backward", backward);
+ // }
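Once compiled, the schema above is callable as `torch.ops.state_wind_backstepping.forward`, with the outputs passed in as pre-allocated mutable tensors. The shapes below are inferred from the kernel's indexing (`w, q, k, v, z, a` and `y` are bfloat16 `[B, T, H, C]`, the carried state is float32 `[B, H, C, C]`, and `s`/`sa` are float32 scratch buffers for the backward pass); the sizes themselves are only an example:

```python
import torch

B, T, H, C, CHUNK_LEN = 1, 32, 32, 64, 16  # example sizes; T must be a multiple of CHUNK_LEN
bf16 = dict(dtype=torch.bfloat16, device="cuda")

w, q, k, v, z, a = (torch.randn(B, T, H, C, **bf16) for _ in range(6))
state = torch.zeros(B, H, C, C, dtype=torch.float32, device="cuda")              # carried wkv state
y = torch.empty(B, T, H, C, **bf16)                                              # per-step output
s = torch.empty(B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device="cuda")  # chunked state checkpoints
sa = torch.empty(B, T, H, C, dtype=torch.float32, device="cuda")                 # per-step a·state read-outs

# The op only has a CUDA implementation, so every tensor must live on the GPU.
torch.ops.state_wind_backstepping.forward(state, w, q, k, v, z, a, y, s, sa)
print(y.shape, state.abs().max())
```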
cuda/wkv7_cuda.cu ADDED
@@ -0,0 +1,138 @@
+ #include <cuda_bf16.h>
+ #include <assert.h>
+
+ using bf = __nv_bfloat16;
+ __device__ inline float to_float(const bf & u) { return __bfloat162float(u); }
+ __device__ inline bf to_bf(const float & u) { return __float2bfloat16_rn(u); }
+
+ typedef bf * __restrict__ F_;
+
+ __global__ void forward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, bf* y_, float* s_, float* sa_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+     float state[C] = {0};
+     __shared__ float q[C], k[C], w[C], a[C], b[C];
+
+     for (int t = 0; t < T; t++) {
+         int ind = bb*T*H*C + t*H*C + hh * C + i;
+         __syncthreads();
+         q[i] = to_float(q_[ind]);
+         w[i] = __expf(-__expf(to_float(w_[ind])));
+         k[i] = to_float(k_[ind]);
+         a[i] = to_float(a_[ind]);
+         b[i] = to_float(b_[ind]);
+         __syncthreads();
+
+         float sa = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             sa += a[j] * state[j];
+         }
+         sa_[ind] = sa;
+
+         float v = to_float(v_[ind]);
+         float y = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             float& s = state[j];
+             s = s * w[j] + sa * b[j] + k[j] * v;
+             y += s * q[j];
+         }
+         y_[ind] = to_bf(y);
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i;
+             #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 s_[base + j*C] = state[j];
+             }
+         }
+     }
+ }
+
+ __global__ void backward_kernel(int T, int H, F_ w_, F_ q_, F_ k_, F_ v_, F_ a_, F_ b_, F_ dy_, float * __restrict__ s_, float * __restrict__ sa_, bf* dw_, bf* dq_, bf* dk_, bf* dv_, bf* da_, bf* db_) {
+     constexpr int C = _C_;
+     int bb = blockIdx.y, hh = blockIdx.x, i = threadIdx.x;
+
+     float stateT[C] = {0}, dstate[C] = {0}, dstateT[C] = {0};
+     __shared__ float w[C], q[C], k[C], v[C], a[C], b[C], dy[C], sa[C], dSb_shared[C];
+     float qi, wi, ki, ai, bi, dyi;
+
+     for (int t = T-1; t >= 0; t--) {
+         int ind = bb*T*H*C + t*H*C + hh * C + i;
+         __syncthreads();
+         q[i] = qi = to_float(q_[ind]);
+         float wi_fac = -__expf(to_float(w_[ind]));
+         w[i] = wi = __expf(wi_fac);
+         k[i] = ki = to_float(k_[ind]);
+         a[i] = ai = to_float(a_[ind]);
+         b[i] = bi = to_float(b_[ind]);
+         v[i] = to_float(v_[ind]);
+         dy[i] = dyi = to_float(dy_[ind]);
+         sa[i] = sa_[ind];
+         __syncthreads();
+
+         if ((t+1)%_CHUNK_LEN_ == 0) {
+             int base = (bb*H+hh)*(T/_CHUNK_LEN_)*C*C + (t/_CHUNK_LEN_)*C*C + i*C;
+             #pragma unroll
+             for (int j = 0; j < C; j++) {
+                 stateT[j] = s_[base + j];
+             }
+         }
+
+         float dq = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dq += stateT[j]*dy[j];
+         }
+         dq_[ind] = to_bf(dq);
+
+         float iwi = 1.0f/wi;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             stateT[j] = (stateT[j] - ki*v[j] - bi*sa[j]) * iwi;
+             dstate[j] += dyi * q[j];
+             dstateT[j] += qi * dy[j];
+         }
+
+         float dw = 0, dk = 0, dv = 0, db = 0, dSb = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dw += dstateT[j]*stateT[j];
+             dk += dstateT[j]*v[j];
+             dv += dstate[j]*k[j];
+             dSb += dstate[j]*b[j];
+             db += dstateT[j]*sa[j];
+         }
+         dw_[ind] = to_bf(dw * wi * wi_fac);
+         dk_[ind] = to_bf(dk);
+         dv_[ind] = to_bf(dv);
+         db_[ind] = to_bf(db);
+
+         __syncthreads();
+         dSb_shared[i] = dSb;
+         __syncthreads();
+
+         float da = 0;
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             da += stateT[j]*dSb_shared[j];
+         }
+         da_[ind] = to_bf(da);
+
+         #pragma unroll
+         for (int j = 0; j < C; j++) {
+             dstate[j] = dstate[j]*w[j] + dSb * a[j];
+             dstateT[j] = dstateT[j]*wi + ai * dSb_shared[j];
+         }
+     }
+ }
+
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa) {
+     forward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,y,s,sa);
+ }
+ void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da) {
+     assert(T%_CHUNK_LEN_ == 0);
+     backward_kernel<<<dim3(H,B), dim3(_C_)>>>(T,H,w,q,k,v,z,a,dy,s,sa,dw,dq,dk,dv,dz,da);
+ }
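For readers who prefer PyTorch to CUDA, the recurrence that `forward_kernel` implements can be written out step by step. This is a slow reference sketch derived from the kernel above (not a file in this upload): the decay is `exp(-exp(w))`, `sa = S·a` is the in-context read-out, and each step updates `S ← S·diag(decay) + sa·bᵀ + v·kᵀ` before producing `y = S·q`.

```python
import torch

def wkv7_reference(w, q, k, v, a, b, state=None):
    """Naive per-timestep reference of forward_kernel; inputs are [B, T, H, C]."""
    B, T, H, C = w.shape
    decay = torch.exp(-torch.exp(w.float()))                   # same transform as __expf(-__expf(w))
    q, k, v, a, b = (x.float() for x in (q, k, v, a, b))
    S = torch.zeros(B, H, C, C, device=w.device) if state is None else state.clone()
    y = torch.empty(B, T, H, C, device=w.device)
    for t in range(T):
        sa = torch.einsum("bhij,bhj->bhi", S, a[:, t])         # sa_i = sum_j a_j * S_ij
        S = (S * decay[:, t].unsqueeze(-2)                     # per-key-channel decay
             + sa.unsqueeze(-1) * b[:, t].unsqueeze(-2)        # state feedback term
             + v[:, t].unsqueeze(-1) * k[:, t].unsqueeze(-2))  # new key/value outer product
        y[:, t] = torch.einsum("bhij,bhj->bhi", S, q[:, t])    # y_i = sum_j S_ij * q_j
    return y, S
```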
cuda/wkv7_op.cpp ADDED
@@ -0,0 +1,34 @@
+ #include <torch/extension.h>
+ #include <cuda_bf16.h>
+ using bf = __nv_bfloat16;
+
+ void cuda_forward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*y, float*s, float*sa);
+
+ void forward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &y, torch::Tensor &s, torch::Tensor &sa) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_forward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)y.data_ptr(), (float*)s.data_ptr(), (float*)sa.data_ptr());
+ }
+
+ void cuda_backward(int B, int T, int H, bf*w, bf*q, bf*k, bf*v, bf*z, bf*a, bf*dy, float*s, float*sa, bf*dw, bf*dq, bf*dk, bf*dv, bf*dz, bf*da);
+
+ void backward(torch::Tensor &w, torch::Tensor &q, torch::Tensor &k, torch::Tensor &v, torch::Tensor &z, torch::Tensor &a, torch::Tensor &dy,
+               torch::Tensor &s, torch::Tensor &sa, torch::Tensor &dw, torch::Tensor &dq, torch::Tensor &dk, torch::Tensor &dv, torch::Tensor &dz, torch::Tensor &da) {
+     int B = w.sizes()[0], T = w.sizes()[1], H = w.sizes()[2];
+     cuda_backward(B, T, H, (bf*)w.data_ptr(), (bf*)q.data_ptr(), (bf*)k.data_ptr(), (bf*)v.data_ptr(), (bf*)z.data_ptr(), (bf*)a.data_ptr(), (bf*)dy.data_ptr(),
+                   (float*)s.data_ptr(), (float*)sa.data_ptr(), (bf*)dw.data_ptr(), (bf*)dq.data_ptr(), (bf*)dk.data_ptr(), (bf*)dv.data_ptr(), (bf*)dz.data_ptr(), (bf*)da.data_ptr());
+ }
+
+ TORCH_LIBRARY(wind_backstepping, m) {
+     m.def("forward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor(a!) y, Tensor(b!) s, Tensor(c!) sa) -> ()");
+     m.def("backward(Tensor w, Tensor q, Tensor k, Tensor v, Tensor z, Tensor a, Tensor dy, Tensor s, Tensor sa, Tensor(a!) dw, Tensor(b!) dq, Tensor(c!) dk, Tensor(d!) dv, Tensor(e!) dz, Tensor(f!) da) -> ()");
+ }
+
+ TORCH_LIBRARY_IMPL(wind_backstepping, CUDA, m) {
+     m.impl("forward", &forward);
+     m.impl("backward", &backward);
+ }
+
+ // TORCH_LIBRARY(wind_backstepping, m) {
+ //     m.def("forward", forward);
+ //     m.def("backward", backward);
+ // }
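On the Python side these mutable-output ops are normally wrapped in a `torch.autograd.Function`; the actual wrapper presumably lives in `modeling_blocks_rwkv7.py`, whose diff is not rendered here, so the following is only a plausible sketch. `CHUNK_LEN` and the buffer shapes are assumptions carried over from the kernel above:

```python
import torch

CHUNK_LEN = 16  # assumption; must match the -D_CHUNK_LEN_ build flag

class WindBackstepping(torch.autograd.Function):
    """Sketch of an autograd wrapper around torch.ops.wind_backstepping."""

    @staticmethod
    def forward(ctx, w, q, k, v, z, a):
        B, T, H, C = w.shape
        y = torch.empty_like(v)
        s = torch.empty(B, H, T // CHUNK_LEN, C, C, dtype=torch.float32, device=w.device)
        sa = torch.empty(B, T, H, C, dtype=torch.float32, device=w.device)
        torch.ops.wind_backstepping.forward(w, q, k, v, z, a, y, s, sa)
        ctx.save_for_backward(w, q, k, v, z, a, s, sa)
        return y

    @staticmethod
    def backward(ctx, dy):
        w, q, k, v, z, a, s, sa = ctx.saved_tensors
        dw, dq, dk, dv, dz, da = (torch.empty_like(x) for x in (w, q, k, v, z, a))
        torch.ops.wind_backstepping.backward(w, q, k, v, z, a, dy.contiguous(),
                                             s, sa, dw, dq, dk, dv, dz, da)
        return dw, dq, dk, dv, dz, da
```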
modeling_blocks_rwkv7.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_rwkv7.py ADDED
@@ -0,0 +1,460 @@
+ """ RWKV Modeling"""
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.utils import (
+     ModelOutput,
+     add_code_sample_docstrings,
+     add_start_docstrings,
+     add_start_docstrings_to_model_forward,
+     is_ninja_available,
+     is_torch_cuda_available,
+     logging,
+ )
+ from transformers.generation import GenerationMixin
+ from transformers.modeling_outputs import ModelOutput
+
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ import torch.nn.functional as F
+
+ import warnings
+ from dataclasses import dataclass
+ from typing import List, Dict, Optional, Tuple, Union, Any
+
+ # Load the RWKV7Config and RWKV7GooseModel
+ from .configuration_rwkv7 import RWKV7Config
+ from .modeling_blocks_rwkv7 import RWKV7GooseModel
+
+ logger = logging.get_logger(__name__)
+
+ class RWKV7PreTrainedModel(PreTrainedModel, RWKV7GooseModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
+     """
+     config_class = RWKV7Config
+     base_model_prefix = "rwkv7"
+     is_parallelizable = True
+     _no_split_modules = ["RWKV7LayerBlock"]
+     _keep_in_fp32_modules = []
+     supports_gradient_checkpointing = True
+
+     def __init__(self, config: RWKV7Config):
+         RWKV7GooseModel.__init__(self, config.__dict__)
+         self.config = config
+
+     def _init_weights(self, module):
+         # Prefer the module's own initialization if it provides one
+         if hasattr(module, 'reset_parameters'):
+             module.reset_parameters()
+             return
+         elif hasattr(module, 'init_parameters'):
+             module.init_parameters()
+             return
+
+         # Default initializer_range for Linear / LN layers
+         initializer_range = 0.02
+
+         if isinstance(module, (nn.ParameterList, nn.ModuleList)):
+             # Iterate and initialize each parameter
+             for param in module:
+                 self._init_weights(param)
+         elif isinstance(module, nn.ParameterDict):
+             # Iterate and initialize each parameter
+             for key, param in module.items():
+                 self._init_weights(param)
+
+         elif isinstance(module, (nn.Linear, nn.Conv1d)):
+             # Slightly different from the TF version which uses truncated_normal for initialization
+             # cf https://github.com/pytorch/pytorch/pull/5617
+             nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.LayerNorm):
+             nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
+         elif isinstance(module, nn.Parameter):
+             nn.init.normal_(module, mean=0.0, std=initializer_range)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
+
+             # # RWKV does not use a blank pad idx. The pad_idx is a training token
+             # if module.padding_idx is not None:
+             #     module.weight.data[module.padding_idx].zero_()
+
+ @dataclass
+ class RWKV7Output(ModelOutput):
+     """
+     Class for the RWKV model outputs.
+     Args:
+         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+             Sequence of hidden-states at the output of the last layer of the model.
+         rwkv_state (list of per-layer RWKV state tuples of `torch.Tensor`):
+             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+             avoid providing the old `input_ids`.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+             the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+             the self-attention heads.
+     """
+     last_hidden_state: torch.FloatTensor = None
+     rwkv_state: Optional[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ @dataclass
+ class RWKV7CausalLMOutput(ModelOutput):
+     """
+     Base class for causal language model (or autoregressive) outputs.
+     Args:
+         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+             Language modeling loss (for next-token prediction).
+         logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+         rwkv_state (list of per-layer RWKV state tuples of `torch.Tensor`):
+             The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
+             avoid providing the old `input_ids`.
+         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+             Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+             one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
+             the model at the output of each layer plus the optional initial embedding outputs.
+         attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+             Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+             sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+             the self-attention heads.
+     """
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     rwkv_state: Optional[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ RWKV7_START_DOCSTRING = r"""
+     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
+     etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
+     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
+     general usage and behavior.
+     Parameters:
+         config ([`Rwkv7Config`]): Model configuration class with all the parameters of the model.
+             Initializing with a config file does not load the weights associated with the model, only the
+             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+ """
+
+ RWKV7_INPUTS_DOCSTRING = r"""
+     Args:
+         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+             `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+             `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+             sequence tokens in the vocabulary. If `past_key_values` is used, only `input_ids` that do not have their
+             past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See
+             [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
+             IDs?](../glossary#input-ids)
+         attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+             - 1 for tokens that are **not masked**,
+             - 0 for tokens that are **masked**.
+             [What are attention masks?](../glossary#attention-mask)
+         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+             model's internal embedding lookup matrix.
+
+         rwkv_state (list of per-layer RWKV block states `(batch_size, hidden_state)`, *optional*):
+             If passed along, the model uses the previous state in all the blocks (which will give the output for the
+             `input_ids` provided as if the model received `state_input_ids + input_ids` as context).
+
+         use_cache (`bool`, *optional*):
+             If set to `True`, the last state is returned and can be used to quickly generate the next logits.
+         output_attentions (`bool`, *optional*):
+             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+             tensors for more detail.
+         output_hidden_states (`bool`, *optional*):
+             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+             more detail.
+
+         return_dict (`bool`, *optional*):
+             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+
+ @add_start_docstrings(
+     "The bare RWKV7 Model transformer outputting raw hidden-states without the language modeling head on top (the head variable is still declared).",
+     RWKV7_START_DOCSTRING,
+ )
+ class RWKV7Model(RWKV7PreTrainedModel):
+     def __init__(self, config: RWKV7Config):
+         super().__init__(config)
+
+     def get_input_embeddings(self):
+         return self.emb
+
+     def set_input_embeddings(self, value):
+         self.emb = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         output_type=RWKV7Output,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,  # not in use
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         rwkv_state: Optional[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ) -> Union[Tuple, RWKV7Output]:
+
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if output_attentions:
+             logger.warning_once("`RWKV7Model` does not support `output_attentions` for now, setting it to `False`.")
+             output_attentions = False
+
+         if self.gradient_checkpointing and self.training and use_cache:
+             logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+             use_cache = False
+
+         if output_hidden_states:
+             logger.warning_once("`RWKV7Model` does not support `output_hidden_states` for now, setting it to `False`.")
+             output_hidden_states = False
+
+         # ---
+
+         # Compute the input embeddings
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         if input_ids is None and inputs_embeds is None:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+         if inputs_embeds is None:
+             inputs_embeds = self.emb(input_ids.to(self.emb.weight.device))
+         x_hidden_state = inputs_embeds
+
+         # Initialize the rwkv_state / prv_stateList
+         if rwkv_state is None or not use_cache:
+             rwkv_state = self.get_init_state(batch_size=x_hidden_state.shape[0])
+         prv_stateList = rwkv_state
+
+         # Initialize the ret_stateList
+         ret_stateList = self.get_init_state(batch_size=x_hidden_state.shape[0], skip_init_state=True)
+
+         all_hidden_states = () if output_hidden_states else None
+         all_attns = () if output_attentions else None
+         v_first = None
+         ret_sublist = None
+
+         # Let's iterate through the blocks
+         for i, block in enumerate(self.blocks):
+             # Build the full inner hidden state
+             if output_hidden_states:
+                 all_hidden_states += (x_hidden_state,)
+
+             # Forward the block
+             if self.gradient_checkpointing and self.training:
+                 x_hidden_state, ret_sublist, v_first = self._gradient_checkpointing_func(
+                     block.__call__, x_hidden_state, prv_stateList[i], v_first
+                 )
+                 ret_stateList[i] = ret_sublist
+             else:
+                 x_hidden_state, ret_sublist, v_first = block(x_hidden_state, prv_stateList[i], v_first)
+                 ret_stateList[i] = ret_sublist
+
+             # if output_attentions:
+             #     all_attns += (ret_sublist,)
+
+         # Final layer norm
+         x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True)
+         x_hidden_state = self.ln_out(x_hidden_state)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (x_hidden_state,)
+
+         # Return the updated per-layer states, so they can be fed back in on the next call
+         if not return_dict:
+             return tuple(i for i in [x_hidden_state, ret_stateList, all_hidden_states, all_attns] if i is not None)
+         return RWKV7Output(
+             last_hidden_state=x_hidden_state,
+             rwkv_state=ret_stateList,
+             hidden_states=all_hidden_states,
+             attentions=all_attns
+         )
+
+ @add_start_docstrings(
+     """
+     The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+     """,
+     RWKV7_START_DOCSTRING,
+ )
+ class RWKV7ForCausalLM(RWKV7Model, GenerationMixin):
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.post_init()
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids=None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         use_cache: bool = True,
+         rwkv_state: Optional[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None,
+         # num_new_tokens_if_rwkv_state: int = 1, # Only triggers if given input_ids + rwkv_state
+         num_logits_to_keep: Optional[int] = None,
+         **kwargs
+     ):
+         '''
+         Personal notes on Hugging Face's barely documented `transformers` generation hooks:
+
+         I assume this is triggered once, at the start of inference,
+         with the kwargs for each subsequent per-token forward call being updated through
+         the `_update_model_kwargs_for_generation` function instead?
+         '''
+         # # only last token for `inputs_ids` if the `past_key_values` is passed along.
+         # if rwkv_state is not None and input_ids is not None:
+         #     input_ids = input_ids[:, -num_new_tokens_if_rwkv_state:]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None:
+             if input_ids is not None:
+                 raise ValueError("You cannot specify both `inputs_ids` and `inputs_embeds` at the same time")
+             model_inputs = {'inputs_embeds': inputs_embeds}
+         else:
+             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+             # recompiles graphs as the stride of the inputs is a guard.
+             # Ref: https://github.com/huggingface/transformers/pull/29114
+             # TODO: use `next_tokens` directly instead.
+             model_inputs = {'input_ids': input_ids.contiguous()}
+
+         if num_logits_to_keep is not None:
+             model_inputs['num_logits_to_keep'] = num_logits_to_keep
+
+         model_inputs.update({
+             'rwkv_state': rwkv_state,
+             'use_cache': use_cache,
+             'attention_mask': attention_mask,
+             'num_logits_to_keep': num_logits_to_keep,
+         })
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(
+         self, outputs: ModelOutput,
+         model_kwargs: Dict[str, Any],
+         num_new_tokens: int = 1,
+         **kwargs
+     ) -> Dict[str, Any]:
+         # Overwritten -- this model uses `rwkv_state`, but doesn't have a cache (`past_key_values`)
+         rwkv_state = outputs.get("rwkv_state", None)
+         input_ids = model_kwargs.get("input_ids", None)
+         attention_mask = model_kwargs.get("attention_mask", None)
+
+         # Carry the updated recurrent state forward to the next generation step
+         model_kwargs["rwkv_state"] = rwkv_state
+
+         # only keep the last tokens of `input_ids` if the state is passed along.
+         if rwkv_state is not None and input_ids is not None and num_new_tokens > 0:
+             input_ids = input_ids[:, -num_new_tokens:]
+             model_kwargs["input_ids"] = input_ids
+
+         if attention_mask is not None:
+             attention_mask = attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))
+             model_kwargs["attention_mask"] = attention_mask
+
+         # Return the formatted output
+         return model_kwargs
+
+     @add_start_docstrings_to_model_forward(RWKV7_INPUTS_DOCSTRING)
+     @add_code_sample_docstrings(
+         output_type=RWKV7CausalLMOutput,
+     )
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,  # noqa
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         rwkv_state: Optional[list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs
+     ) -> Union[Tuple, RWKV7CausalLMOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         rwkv_outputs = RWKV7Model.forward(
+             self, input_ids, attention_mask, inputs_embeds,
+             rwkv_state, use_cache, output_attentions, output_hidden_states,
+             return_dict=False
+         )
+
+         # Get the hidden state, and the updated RWKV state
+         hidden_states = rwkv_outputs[0]
+         rwkv_state = rwkv_outputs[1]
+
+         # Get ALL the hidden states and attention dumps
+         all_hidden_states = rwkv_outputs[2] if output_hidden_states else None
+         if output_hidden_states:
+             all_attns = rwkv_outputs[3] if output_attentions else None
+         else:
+             all_attns = rwkv_outputs[2] if output_attentions else None
+
+         # Forward the head state
+         logits = self.head(hidden_states)
+
+         # Compute the loss from the labels
+         loss = None
+         if labels is not None:
+             # move labels to correct device to enable model parallelism
+             labels = labels.to(logits.device)
+             # Shift so that tokens < n predict n
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+
+             # Compute the token loss, masked by the attention mask when one is provided
+             if attention_mask is not None:
+                 token_loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none")
+                 submask = attention_mask[..., 1:].contiguous().view(-1)
+                 loss = (token_loss * submask).sum() / submask.sum()
+             else:
+                 loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="mean")
+
+         if not return_dict:
+             return tuple(i for i in [loss, logits, rwkv_state, all_hidden_states, all_attns] if i is not None)
+
+         return RWKV7CausalLMOutput(
+             loss=loss,
+             logits=logits,
+             rwkv_state=rwkv_state,
+             hidden_states=all_hidden_states,
+             attentions=all_attns,
+         )
+
+ __all__ = ["RWKV7ForCausalLM", "RWKV7Model", "RWKV7PreTrainedModel"]
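With the tokenizer files below and the `auto_map` in `config.json`, the whole upload can be driven through the standard remote-code path. An end-to-end sketch; the repo id is a placeholder and the prompt is arbitrary:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "picocreator/rwkv7-goose-upload"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(
    repo_id, trust_remote_code=True, torch_dtype=torch.bfloat16
).to("cuda")

# Normal generation; prepare_inputs_for_generation / _update_model_kwargs_for_generation
# above carry the rwkv_state between steps instead of a KV cache.
inputs = tokenizer("The RWKV architecture is", return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0]))

# Manual forward calls with explicit state reuse between segments.
with torch.no_grad():
    first = model(input_ids=inputs.input_ids, use_cache=True)
    nxt = model(input_ids=out[:, -1:], rwkv_state=first.rwkv_state, use_cache=True)
```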
special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "tokenizer_class": "GPTNeoXTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff