vikhyatk commited on
Commit
78c093c
·
verified ·
1 Parent(s): 28e93ab

Upload HfMoondream

Browse files
Files changed (1) hide show
  1. text.py +18 -12
text.py CHANGED
@@ -28,18 +28,24 @@ def attn(
28
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
29
  q_dim = n_heads * head_dim
30
  kv_dim = n_kv_heads * head_dim
31
-
32
- q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
33
- k = (
34
- qkv_out[..., q_dim : q_dim + kv_dim]
35
- .view(bsz, q_len, n_kv_heads, head_dim)
36
- .transpose(1, 2)
37
- )
38
- v = (
39
- qkv_out[..., q_dim + kv_dim :]
40
- .view(bsz, q_len, n_kv_heads, head_dim)
41
- .transpose(1, 2)
42
- )
 
 
 
 
 
 
43
 
44
  q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
45
  k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
 
28
  qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
29
  q_dim = n_heads * head_dim
30
  kv_dim = n_kv_heads * head_dim
31
+ q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
32
+ del qkv_out
33
+
34
+ q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
35
+ k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
36
+ v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
37
+
38
+ # q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
39
+ # k = (
40
+ # qkv_out[..., q_dim : q_dim + kv_dim]
41
+ # .view(bsz, q_len, n_kv_heads, head_dim)
42
+ # .transpose(1, 2)
43
+ # )
44
+ # v = (
45
+ # qkv_out[..., q_dim + kv_dim :]
46
+ # .view(bsz, q_len, n_kv_heads, head_dim)
47
+ # .transpose(1, 2)
48
+ # )
49
 
50
  q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
51
  k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)