Upload HfMoondream

- layers.py +6 -2
- model.safetensors +2 -2
- packing.py +26 -34
layers.py
CHANGED

@@ -43,7 +43,10 @@ class QuantizedLinear(nn.Module):
                 ),
                 requires_grad=False,
             ),
-            "
+            "scale": nn.Parameter(
+                torch.empty(out_features, in_features // 128), requires_grad=False
+            ),
+            "zero_point": nn.Parameter(
                 torch.empty(out_features, in_features // 128), requires_grad=False
             ),
         }
@@ -55,7 +58,8 @@ class QuantizedLinear(nn.Module):
         self.weight = nn.Parameter(
             dequantize_tensor(
                 self.weight["packed"],
-                self.weight["
+                self.weight["scale"],
+                self.weight["zero_point"],
                 (self.weight["packed"].shape[0], self.weight["packed"].shape[1] * 128),
                 128,
                 torch.bfloat16,
model.safetensors
CHANGED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:325876fadb939f7c65f545d5d37b03f5035681b87bad1073f6d2e804ce2f4068
+size 1881750512
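
The safetensors entry is a Git LFS pointer, so the commit records only the new blob's SHA-256 and byte size. A minimal way to check a downloaded copy against the pointer (assuming the file sits in the current directory) might look like:

```python
import hashlib, os

path = "model.safetensors"  # local download to verify
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        h.update(chunk)
assert h.hexdigest() == "325876fadb939f7c65f545d5d37b03f5035681b87bad1073f6d2e804ce2f4068"
assert os.path.getsize(path) == 1881750512
```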
packing.py
CHANGED

@@ -1,43 +1,35 @@
 import torch
 
 
+def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
+    orig_shape = packed.shape
+    last_dim = orig_shape[-1]
+    batch_shape = orig_shape[:-1]
+    flat_packed = packed.reshape(-1, last_dim)
+    batch_size = flat_packed.shape[0]
+    flat_bytes = flat_packed.reshape(-1)
+    lower = flat_bytes & 0xF
+    upper = (flat_bytes >> 4) & 0xF
+    unpacked = torch.stack([lower, upper], dim=1).reshape(batch_size, last_dim * 2)
+    unpacked = unpacked[:, :original_length]
+    unpacked = unpacked.reshape(*batch_shape, original_length)
+    return unpacked.to(torch.int8)
+
+
 def dequantize_tensor(
     packed: torch.Tensor,
     scales: torch.Tensor,
+    zero_points: torch.Tensor,
     orig_shape: torch.Size,
     block_size: int,
-    dtype: torch.dtype,
+    dtype: torch.dtype = torch.bfloat16,
 ):
-
-
-
-
-
-
-
-
-
-    pr = packed.view(num_blocks, num_bytes)
-
-    # prepare output in the target dtype
-    out = torch.empty((num_blocks, block_size), device=packed.device, dtype=dtype)
-
-    # ---- lower nibble ----
-    lower = pr & 0xF  # [blocks, bytes]
-    lower = lower.to(torch.int8)  # cast to signed
-    lower[lower >= 8] -= 16  # sign-correct
-
-    lo_count = (block_size + 1) // 2
-    out[:, 0:block_size:2] = lower[:, :lo_count].to(dtype) * scales.view(-1, 1)
-
-    # ---- upper nibble ----
-    pr >>= 4  # in-place shift of the original packed bytes
-    upper = pr & 0xF
-    upper = upper.to(torch.int8)
-    upper[upper >= 8] -= 16
-
-    hi_count = block_size // 2
-    out[:, 1:block_size:2] = upper[:, :hi_count].to(dtype) * scales.view(-1, 1)
-
-    # restore original shape
-    return out.view(orig_shape)
+    out_features, num_blocks, _ = packed.shape
+    unpacked = unpack_int4(packed, block_size)
+    scales_view = scales.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
+    zero_points_view = zero_points.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
+    dequantized = (unpacked.float() - zero_points_view) * scales_view
+    dequantized = dequantized.reshape(out_features, num_blocks * block_size)
+    dequantized = dequantized[:, : orig_shape[1]]
+    dequantized = dequantized.reshape(orig_shape)
+    return dequantized.to(dtype)
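
Because `unpack_int4` emits each byte's low nibble before its high nibble, a matching packer has to interleave values in the same order. Below is a round-trip sketch under that assumption: `pack_int4` is illustrative, not code from this commit, and it assumes the new packing.py is importable as `packing`. The shapes mirror what layers.py passes in (`block_size=128`, so 64 packed bytes per block):

```python
import torch
from packing import unpack_int4, dequantize_tensor  # the file added above

def pack_int4(q: torch.Tensor) -> torch.Tensor:
    # Pack uint4 values (0..15) two per byte: even index -> low nibble,
    # odd index -> high nibble, mirroring unpack_int4's [lower, upper] order.
    *batch, n = q.shape
    if n % 2:  # pad odd lengths; unpack_int4 trims back via original_length
        q = torch.cat([q, q.new_zeros(*batch, 1)], dim=-1)
    pairs = q.reshape(*batch, -1, 2).to(torch.uint8)
    return pairs[..., 0] | (pairs[..., 1] << 4)

# block_size=128 packs to 64 bytes per block, as layers.py assumes.
out_features, num_blocks, block_size = 4, 3, 128
q = torch.randint(0, 16, (out_features, num_blocks, block_size), dtype=torch.uint8)
packed = pack_int4(q)  # [4, 3, 64]
assert torch.equal(unpack_int4(packed, block_size), q.to(torch.int8))

# Full dequantize call with per-block scale / zero_point, as in QuantizedLinear.
scales = torch.rand(out_features, num_blocks)
zero_points = torch.randint(0, 16, (out_features, num_blocks)).float()
w = dequantize_tensor(
    packed, scales, zero_points,
    torch.Size([out_features, num_blocks * block_size]), block_size,
)
assert w.shape == (out_features, num_blocks * block_size) and w.dtype == torch.bfloat16
```

Note that the rewrite trades the old in-place nibble arithmetic (signed int4, strided writes into a preallocated output) for a vectorized unsigned unpack followed by a single `(q - zero_point) * scale` broadcast, which is why the symmetric sign-correction code disappears from the diff.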