Upload HfMoondream
Browse files
text.py
CHANGED
@@ -28,18 +28,24 @@ def attn(
|
|
28 |
qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
|
29 |
q_dim = n_heads * head_dim
|
30 |
kv_dim = n_kv_heads * head_dim
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
|
45 |
k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
|
|
|
28 |
qkv_out = w.qkv(x) # shape: (bsz, q_len, (n_heads + 2*n_kv_heads)*head_dim)
|
29 |
q_dim = n_heads * head_dim
|
30 |
kv_dim = n_kv_heads * head_dim
|
31 |
+
q, k, v = qkv_out.split([q_dim, kv_dim, kv_dim], dim=-1)
|
32 |
+
del qkv_out
|
33 |
+
|
34 |
+
q = q.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
|
35 |
+
k = k.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
|
36 |
+
v = v.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
|
37 |
+
|
38 |
+
# q = qkv_out[..., :q_dim].view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
|
39 |
+
# k = (
|
40 |
+
# qkv_out[..., q_dim : q_dim + kv_dim]
|
41 |
+
# .view(bsz, q_len, n_kv_heads, head_dim)
|
42 |
+
# .transpose(1, 2)
|
43 |
+
# )
|
44 |
+
# v = (
|
45 |
+
# qkv_out[..., q_dim + kv_dim :]
|
46 |
+
# .view(bsz, q_len, n_kv_heads, head_dim)
|
47 |
+
# .transpose(1, 2)
|
48 |
+
# )
|
49 |
|
50 |
q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
|
51 |
k = apply_rotary_emb(k, freqs_cis, position_ids, n_kv_heads)
|