vikhyatk committed · Commit 28e93ab · verified · 1 Parent(s): 1ccf5fd

Upload HfMoondream

Files changed (1): packing.py (+31 -40)
packing.py CHANGED
@@ -1,20 +1,6 @@
 import torch
 
 
-def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
-    """
-    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a 1D tensor of int8 values,
-    vectorized over the entire input.
-    """
-    lower = packed & 0xF
-    upper = (packed >> 4) & 0xF
-    # Interleave lower and upper nibbles
-    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
-    nibbles = nibbles.to(torch.int8)
-    nibbles[nibbles >= 8] -= 16
-    return nibbles
-
-
 def dequantize_tensor(
     packed: torch.Tensor,
     scales: torch.Tensor,
@@ -23,30 +9,35 @@ def dequantize_tensor(
     dtype: torch.dtype,
 ):
     """
-    Dequantizes a packed int4 tensor (with given per-block scales) back to bfloat16,
-    using vectorized operations to avoid Python loops.
+    In-place–friendly dequantization of int4-packed data back to `dtype`,
+    mutating `packed` (and reading `scales`) to avoid extra big intermediates.
     """
-    num_bytes_per_block = (block_size + 1) // 2  # number of packed bytes per block
-    num_blocks_total = packed.numel() // num_bytes_per_block
-    # Reshape to (num_blocks_total, num_bytes_per_block)
-    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)
-
-    # Vectorized unpacking: compute lower and upper nibbles for all rows at once.
-    lower = packed_rows & 0xF
-    upper = (packed_rows >> 4) & 0xF
-    # Create a new dimension for the two nibbles and then flatten.
-    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)
-    # Slice to get exactly block_size values per block.
-    quantized_flat = nibbles[:, :block_size].to(torch.int8)
-    quantized_flat[quantized_flat >= 8] -= 16
-
-    # Reshape to original block structure.
-    last_dim = orig_shape[-1]
-    num_blocks = last_dim // block_size
-    new_shape = orig_shape[:-1] + (num_blocks, block_size)
-    quantized = quantized_flat.view(new_shape)
-
-    # Dequantize using scales.
-    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
-    dequantized = dequantized.view(orig_shape)
-    return dequantized.to(dtype)
+    # how many bytes encode each block of `block_size` 4-bit values
+    num_bytes = (block_size + 1) // 2
+    num_blocks = packed.numel() // num_bytes
+
+    # view as [blocks, bytes_per_block]
+    pr = packed.view(num_blocks, num_bytes)
+
+    # prepare output in the target dtype
+    out = torch.empty((num_blocks, block_size), device=packed.device, dtype=dtype)
+
+    # ---- lower nibble ----
+    lower = pr & 0xF  # [blocks, bytes]
+    lower = lower.to(torch.int8)  # cast to signed
+    lower[lower >= 8] -= 16  # sign-correct
+
+    lo_count = (block_size + 1) // 2
+    out[:, 0:block_size:2] = lower[:, :lo_count].to(dtype) * scales.view(-1, 1)
+
+    # ---- upper nibble ----
+    pr >>= 4  # in-place shift of the original packed bytes
+    upper = pr & 0xF
+    upper = upper.to(torch.int8)
+    upper[upper >= 8] -= 16
+
+    hi_count = block_size // 2
+    out[:, 1:block_size:2] = upper[:, :hi_count].to(dtype) * scales.view(-1, 1)
+
+    # restore original shape
+    return out.view(orig_shape)
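
For reference, a minimal round-trip sketch of how the new `dequantize_tensor` can be exercised. Assumptions not visible in the diff: the two parameters elided between the hunks are `orig_shape` and `block_size` in that order (inferred from the function body), values are packed lower-nibble-first as in the removed `unpack_int4`, and the `pack_int4` helper here is hypothetical, written only for this check:

```python
import torch

def pack_int4(values: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of the unpacking logic: pair signed int4 values
    # (range [-8, 7]) into uint8 bytes, lower nibble first.
    assert values.numel() % 2 == 0
    nibbles = (values.to(torch.int16) & 0xF).to(torch.uint8)  # two's-complement nibbles
    pairs = nibbles.view(-1, 2)
    return pairs[:, 0] | (pairs[:, 1] << 4)

orig_shape = (4, 32)
block_size = 16
w = torch.randn(orig_shape)

# Per-block symmetric quantization to int4.
blocks = w.view(-1, block_size)                               # [num_blocks, block_size]
scales = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7
q = (blocks / scales).round().clamp(-8, 7).to(torch.int8)

packed = pack_int4(q.view(-1))                                # uint8, two values per byte

# Assumed parameter order: (packed, scales, orig_shape, block_size, dtype).
out = dequantize_tensor(packed, scales.view(-1), orig_shape, block_size, torch.bfloat16)

assert out.shape == orig_shape and out.dtype == torch.bfloat16
# Max error should be on the order of half a quantization step (scale / 2),
# plus bfloat16 rounding.
print((out.float() - w).abs().max().item())
```

Note that `dequantize_tensor` shifts the packed bytes in place (`pr >>= 4` mutates a view of `packed`), so the packed buffer is consumed by the call and cannot be dequantized twice.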