from transformers import PretrainedConfig


class YingLongConfig(PretrainedConfig):
    """Configuration class for the YingLong time-series forecasting model.

    Stores the architecture hyperparameters consumed by the model
    implementation: transformer dimensions, rotary position embedding
    settings, and Haar-transform options.
    """

    model_type = "yinglong"

    def __init__(
        self,
        bias: bool = False,
        condense_ratio: int = 1,
        haar_trans: bool = True,
        haar_trans_inv: bool = True,
        haar_trans_norm: str = "backward",
        half_diff: bool = False,
        intermediate_size: int = 1024,
        n_embd: int = 256,
        n_head: int = 16,
        n_layer: int = 6,
        n_query_groups: int = 4,
        norm_eps: float = 1e-5,
        org: str = "Alibaba",
        patch_size: int = 32,
        rope_base: int = 10000,
        rotary_percentage: float = 1.0,
        shared_attention_norm: bool = False,
        unet: bool = True,
        _mlp_class: str = "LLaMAMLP",
        _norm_class: str = "FusedRMSNorm",
        **kwargs,
    ):
        # General model settings.
        self.org = org
        self.patch_size = patch_size
        self.unet = unet

        # Transformer dimensions.
        self.n_embd = n_embd
        self.intermediate_size = intermediate_size
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_query_groups = n_query_groups
        self.norm_eps = norm_eps
        self.bias = bias
        self.shared_attention_norm = shared_attention_norm

        # Rotary position embedding settings.
        self.condense_ratio = condense_ratio
        self.rope_base = rope_base
        self.rotary_percentage = rotary_percentage

        # Haar transform options.
        self.haar_trans = haar_trans
        self.haar_trans_inv = haar_trans_inv
        self.haar_trans_norm = haar_trans_norm
        self.half_diff = half_diff

        # Names of the normalization and MLP implementations to use.
        self._norm_class = _norm_class
        self._mlp_class = _mlp_class

        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        assert self.n_head % self.n_query_groups == 0, "n_head must be divisible by n_query_groups"

        # Derived values.
        self.head_size = self.n_embd // self.n_head
        self.rope_n_elem = int(self.rotary_percentage * self.head_size)
        self.rope_condense_ratio = self.condense_ratio

        super().__init__(**kwargs)