import torch


def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
    """
    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a 1D tensor of int8 values,
    vectorized over the entire input.
    """
    lower = packed & 0xF
    upper = (packed >> 4) & 0xF
    # Interleave lower and upper nibbles
    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
    nibbles = nibbles.to(torch.int8)
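    # Sign-extend: nibble values 8..15 represent the negative half of the int4 range.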
    nibbles[nibbles >= 8] -= 16
    return nibbles
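

# Example (for illustration): the byte 0xAB packs lower nibble 0xB (-> -5) and
# upper nibble 0xA (-> -6):
#   >>> unpack_int4(torch.tensor([0xAB], dtype=torch.uint8), original_length=2)
#   tensor([-5, -6], dtype=torch.int8)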


def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Dequantizes a packed int4 tensor (with given per-block scales) back to bfloat16,
    using vectorized operations to avoid Python loops.
    """
    num_bytes_per_block = (block_size + 1) // 2  # number of packed bytes per block
    num_blocks_total = packed.numel() // num_bytes_per_block
    # Reshape to (num_blocks_total, num_bytes_per_block)
    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)

    # Vectorized unpacking: compute lower and upper nibbles for all rows at once.
    lower = packed_rows & 0xF
    upper = (packed_rows >> 4) & 0xF
    # Create a new dimension for the two nibbles and then flatten.
    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)
    # Slice to get exactly block_size values per block.
    quantized_flat = nibbles[:, :block_size].to(torch.int8)
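    # Sign-extend nibbles 8..15 to the negative int4 range [-8, -1].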
    quantized_flat[quantized_flat >= 8] -= 16

    # Reshape to original block structure.
    last_dim = orig_shape[-1]
    num_blocks = last_dim // block_size
    new_shape = orig_shape[:-1] + (num_blocks, block_size)
    quantized = quantized_flat.view(new_shape)

    # Dequantize using scales.
    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
    dequantized = dequantized.view(orig_shape)
    return dequantized.to(dtype)
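

# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original module). The packing
# order (lower nibble first) matches dequantize_tensor above, but the scale
# computation is an assumption: a symmetric per-block absmax scheme with
# scale = max|x| / 7. `quantize_tensor` is a hypothetical helper written only
# to exercise the round trip.
def quantize_tensor(x: torch.Tensor, block_size: int):
    """Hypothetical inverse of dequantize_tensor: per-block absmax int4 quantization."""
    orig_shape = x.shape
    blocks = x.reshape(*orig_shape[:-1], -1, block_size).to(torch.float32)
    scales = blocks.abs().amax(dim=-1).clamp(min=1e-8) / 7.0
    q = torch.round(blocks / scales.unsqueeze(-1)).clamp(-8, 7).to(torch.int8)

    # Pack two signed 4-bit values per byte, lower nibble first.
    q_flat = q.reshape(-1, block_size)
    if block_size % 2:  # pad odd-sized blocks with a zero nibble
        pad = torch.zeros(q_flat.shape[0], 1, dtype=torch.int8, device=q_flat.device)
        q_flat = torch.cat([q_flat, pad], dim=1)
    lo = (q_flat[:, 0::2] & 0xF).to(torch.uint8)
    hi = (q_flat[:, 1::2] & 0xF).to(torch.uint8)
    packed = (lo | (hi << 4)).reshape(-1)
    return packed, scales


if __name__ == "__main__":
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    packed, scales = quantize_tensor(x, block_size=32)
    x_hat = dequantize_tensor(packed, scales, x.shape, block_size=32, dtype=torch.bfloat16)
    # Round-trip error should be on the order of half a scale step per element.
    print("max abs error:", (x.float() - x_hat.float()).abs().max().item())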