vikhyatk committed
Commit f089fb8 · verified · 1 Parent(s): 48640e9

Upload HfMoondream

Files changed (3):
1. layers.py +6 -2
2. model.safetensors +2 -2
3. packing.py +26 -34
layers.py CHANGED
@@ -43,7 +43,10 @@ class QuantizedLinear(nn.Module):
                 ),
                 requires_grad=False,
             ),
-            "scales": nn.Parameter(
+            "scale": nn.Parameter(
+                torch.empty(out_features, in_features // 128), requires_grad=False
+            ),
+            "zero_point": nn.Parameter(
                 torch.empty(out_features, in_features // 128), requires_grad=False
             ),
         }
@@ -55,7 +58,8 @@ class QuantizedLinear(nn.Module):
         self.weight = nn.Parameter(
             dequantize_tensor(
                 self.weight["packed"],
-                self.weight["scales"],
+                self.weight["scale"],
+                self.weight["zero_point"],
                 (self.weight["packed"].shape[0], self.weight["packed"].shape[1] * 128),
                 128,
                 torch.bfloat16,
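
This change replaces the single `scales` buffer with separate `scale` and `zero_point` buffers, moving QuantizedLinear from symmetric to asymmetric per-block int4 dequantization. A minimal sketch of the buffer layout the module now appears to expect, for orientation only: the `scale`/`zero_point` shapes come straight from the diff above, while the `packed` shape and uint8 dtype are assumptions inferred from how packing.py indexes it, not facts confirmed by this commit.

import torch
import torch.nn as nn

out_features, in_features, block_size = 2048, 2048, 128
num_blocks = in_features // block_size  # one (scale, zero_point) pair per 128-wide block

weight = {
    # Assumed layout: two 4-bit values per byte; shape inferred from
    # `out_features, num_blocks, _ = packed.shape` in packing.py.
    "packed": nn.Parameter(
        torch.empty(out_features, num_blocks, block_size // 2, dtype=torch.uint8),
        requires_grad=False,
    ),
    # Shapes taken verbatim from the diff above.
    "scale": nn.Parameter(
        torch.empty(out_features, num_blocks), requires_grad=False
    ),
    "zero_point": nn.Parameter(
        torch.empty(out_features, num_blocks), requires_grad=False
    ),
}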
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:97076df1a9a09ff4108a69ea59b4c9abf522b248e8425c9334bab98ddbaf4b33
-size 1838828672
+oid sha256:325876fadb939f7c65f545d5d37b03f5035681b87bad1073f6d2e804ce2f4068
+size 1881750512
packing.py CHANGED
@@ -1,43 +1,35 @@
 import torch
 
 
+def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
+    orig_shape = packed.shape
+    last_dim = orig_shape[-1]
+    batch_shape = orig_shape[:-1]
+    flat_packed = packed.reshape(-1, last_dim)
+    batch_size = flat_packed.shape[0]
+    flat_bytes = flat_packed.reshape(-1)
+    lower = flat_bytes & 0xF
+    upper = (flat_bytes >> 4) & 0xF
+    unpacked = torch.stack([lower, upper], dim=1).reshape(batch_size, last_dim * 2)
+    unpacked = unpacked[:, :original_length]
+    unpacked = unpacked.reshape(*batch_shape, original_length)
+    return unpacked.to(torch.int8)
+
+
 def dequantize_tensor(
     packed: torch.Tensor,
     scales: torch.Tensor,
+    zero_points: torch.Tensor,
     orig_shape: torch.Size,
     block_size: int,
-    dtype: torch.dtype,
+    dtype: torch.dtype = torch.bfloat16,
 ):
-    """
-    In-place-friendly dequantization of int4-packed data back to `dtype`,
-    mutating `packed` (and reading `scales`) to avoid extra big intermediates.
-    """
-    # how many bytes encode each block of `block_size` 4-bit values
-    num_bytes = (block_size + 1) // 2
-    num_blocks = packed.numel() // num_bytes
-
-    # view as [blocks, bytes_per_block]
-    pr = packed.view(num_blocks, num_bytes)
-
-    # prepare output in the target dtype
-    out = torch.empty((num_blocks, block_size), device=packed.device, dtype=dtype)
-
-    # ---- lower nibble ----
-    lower = pr & 0xF  # [blocks, bytes]
-    lower = lower.to(torch.int8)  # cast to signed
-    lower[lower >= 8] -= 16  # sign-correct
-
-    lo_count = (block_size + 1) // 2
-    out[:, 0:block_size:2] = lower[:, :lo_count].to(dtype) * scales.view(-1, 1)
-
-    # ---- upper nibble ----
-    pr >>= 4  # in-place shift of the original packed bytes
-    upper = pr & 0xF
-    upper = upper.to(torch.int8)
-    upper[upper >= 8] -= 16
-
-    hi_count = block_size // 2
-    out[:, 1:block_size:2] = upper[:, :hi_count].to(dtype) * scales.view(-1, 1)
-
-    # restore original shape
-    return out.view(orig_shape)
+    out_features, num_blocks, _ = packed.shape
+    unpacked = unpack_int4(packed, block_size)
+    scales_view = scales.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
+    zero_points_view = zero_points.unsqueeze(2)  # Shape: [out_features, num_blocks, 1]
+    dequantized = (unpacked.float() - zero_points_view) * scales_view
+    dequantized = dequantized.reshape(out_features, num_blocks * block_size)
+    dequantized = dequantized[:, : orig_shape[1]]
+    dequantized = dequantized.reshape(orig_shape)
+    return dequantized.to(dtype)
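
The rewritten `dequantize_tensor` computes `(q - zero_point) * scale` on unsigned nibbles per 128-wide block, where the removed version sign-corrected the nibbles to signed int4 and multiplied by `scales` alone. As a round-trip sanity check, here is a hedged sketch: the `quantize_tensor` helper below is hypothetical, not part of this commit, and is written only to invert `unpack_int4`'s low-nibble-first interleave.

import torch

from packing import dequantize_tensor


def quantize_tensor(w: torch.Tensor, block_size: int = 128):
    # Hypothetical inverse of dequantize_tensor: asymmetric uint4 per block.
    out_features, in_features = w.shape
    num_blocks = in_features // block_size
    blocks = w.float().reshape(out_features, num_blocks, block_size)

    w_min = blocks.min(dim=-1, keepdim=True).values
    w_max = blocks.max(dim=-1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / 15.0  # uint4 range 0..15
    zero_point = (-w_min / scale).round().clamp(0, 15)
    q = (blocks / scale + zero_point).round().clamp(0, 15).to(torch.uint8)

    # Even elements go to the low nibble, odd to the high nibble, matching
    # unpack_int4's stack([lower, upper]) interleave.
    packed = q[..., 0::2] | (q[..., 1::2] << 4)  # [out, num_blocks, block_size // 2]
    return packed, scale.squeeze(-1), zero_point.squeeze(-1)


w = torch.randn(4, 256, dtype=torch.bfloat16)
packed, scale, zero_point = quantize_tensor(w)
w_hat = dequantize_tensor(packed, scale, zero_point, w.shape, block_size=128)
assert w_hat.shape == w.shape and w_hat.dtype == torch.bfloat16
print((w.float() - w_hat.float()).abs().max())  # reconstruction error on the order of `scale`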