import torch


def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
    """
    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a
    1D tensor of signed int8 values, vectorized over the entire input.
    """
    lower = packed & 0xF
    upper = (packed >> 4) & 0xF

    # Interleave (lower, upper) pairs, flatten, and drop the padding nibble
    # that may occupy the upper half of the final byte.
    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
    nibbles = nibbles.to(torch.int8)
    # Sign-extend: nibbles 8..15 encode -8..-1 in two's complement.
    nibbles[nibbles >= 8] -= 16
    return nibbles
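
# Example (illustrative values assumed here): byte 0x21 stores 1 in its lower
# nibble and 2 in its upper nibble, and nibbles >= 8 decode as negative after
# sign extension:
#
#   >>> unpack_int4(torch.tensor([0x21, 0xF7], dtype=torch.uint8), 4)
#   tensor([ 1,  2,  7, -1], dtype=torch.int8)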


def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Dequantize a packed int4 tensor (with per-block scales) back to `dtype`,
    using vectorized operations to avoid Python loops.

    `scales` is expected to hold one scale per block, with shape
    orig_shape[:-1] + (orig_shape[-1] // block_size,).
    """
    # Each block is packed independently, so an odd block_size leaves one
    # padding nibble in the last byte of every block.
    num_bytes_per_block = (block_size + 1) // 2
    num_blocks_total = packed.numel() // num_bytes_per_block

    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)

    lower = packed_rows & 0xF
    upper = (packed_rows >> 4) & 0xF

    # Interleave nibbles within each block row: (blocks, bytes, 2) -> (blocks, 2 * bytes).
    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)

    # Keep only the real values, then sign-extend 8..15 to -8..-1.
    quantized_flat = nibbles[:, :block_size].to(torch.int8)
    quantized_flat[quantized_flat >= 8] -= 16

    # Restore the blocked layout: (..., num_blocks, block_size).
    last_dim = orig_shape[-1]
    num_blocks = last_dim // block_size
    new_shape = orig_shape[:-1] + (num_blocks, block_size)
    quantized = quantized_flat.view(new_shape)

    # Apply one scale per block, then collapse back to the original shape.
    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
    dequantized = dequantized.view(orig_shape)
    return dequantized.to(dtype)
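

# A minimal round-trip sketch (run as a script). The quantization side below
# is an assumption inferred from the unpacking layout above (lower nibble
# first, one scale per block of the last dimension); it is not part of this
# module.
if __name__ == "__main__":
    weights = torch.randn(2, 8)
    block_size = 4

    blocks = weights.view(2, 2, block_size)                   # (rows, num_blocks, block_size)
    scales = blocks.abs().amax(dim=-1).clamp_min(1e-8) / 7.0  # one scale per block
    q = (blocks / scales.unsqueeze(-1)).round().clamp(-8, 7).to(torch.int8)

    nibbles = q.reshape(-1, 2).to(torch.uint8) & 0xF          # two's-complement nibbles
    packed = nibbles[:, 0] | (nibbles[:, 1] << 4)             # lower nibble first

    restored = dequantize_tensor(packed, scales, weights.shape, block_size, torch.float32)
    print(torch.allclose(weights, restored, atol=scales.max().item()))  # expect True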