File size: 2,385 Bytes
4135502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# -*- coding: utf-8 -*-

from .abc import chunk_abc
from .attn import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn
from .based import fused_chunk_based, parallel_based
from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule
from .forgetting_attn import parallel_forgetting_attn
from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from .generalized_delta_rule import (
    chunk_dplr_delta_rule,
    chunk_iplr_delta_rule,
    fused_recurrent_dplr_delta_rule,
    fused_recurrent_iplr_delta_rule
)
from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
from .gsa import chunk_gsa, fused_recurrent_gsa
from .hgrn import fused_recurrent_hgrn
from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn
from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn
from .nsa import parallel_nsa
from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention
from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6
from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla

# Public API of this package: the kernel entry points re-exported from the
# submodule imports at the top of this file. This list is what
# `from <package> import *` exposes. Names are grouped by the submodule they
# come from, in the same order as the import statements.
__all__ = [
    # .abc
    'chunk_abc',
    # .attn
    'parallel_attn',
    'parallel_rectified_attn',
    'parallel_softpick_attn',
    'naive_attn',
    'naive_rectified_attn',
    'naive_softpick_attn',
    # .based
    'fused_chunk_based',
    'parallel_based',
    # .delta_rule
    'chunk_delta_rule',
    'fused_chunk_delta_rule',
    'fused_recurrent_delta_rule',
    # .forgetting_attn
    'parallel_forgetting_attn',
    # .gated_delta_rule
    'chunk_gated_delta_rule',
    'fused_recurrent_gated_delta_rule',
    # .generalized_delta_rule
    'chunk_dplr_delta_rule',
    'chunk_iplr_delta_rule',
    'fused_recurrent_dplr_delta_rule',
    'fused_recurrent_iplr_delta_rule',
    # .gla
    'chunk_gla',
    'fused_chunk_gla',
    'fused_recurrent_gla',
    # .gsa
    'chunk_gsa',
    'fused_recurrent_gsa',
    # .hgrn
    'fused_recurrent_hgrn',
    # .lightning_attn
    'chunk_lightning_attn',
    'fused_recurrent_lightning_attn',
    # .linear_attn
    'chunk_linear_attn',
    'fused_chunk_linear_attn',
    'fused_recurrent_linear_attn',
    # .nsa
    'parallel_nsa',
    # .retention
    'chunk_retention',
    'fused_chunk_retention',
    'fused_recurrent_retention',
    'parallel_retention',
    # .rwkv6
    'chunk_rwkv6',
    'fused_recurrent_rwkv6',
    # .rwkv7
    'chunk_rwkv7',
    'fused_recurrent_rwkv7',
    # .simple_gla
    'chunk_simple_gla',
    'fused_recurrent_simple_gla',
    'parallel_simple_gla',
]