|
import torch |
|
|
|
|
|
def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Dequantize int4-packed data back to ``dtype``.

    Each block of ``block_size`` values is stored as ``ceil(block_size / 2)``
    bytes: even-indexed values occupy the low nibble of each byte and
    odd-indexed values the high nibble. Nibbles are signed 4-bit
    two's-complement (range [-8, 7]); every value in a block is multiplied
    by that block's single scale factor.

    NOTE: ``packed`` is mutated in place (shifted right by 4) to avoid
    allocating a second ``num_blocks x num_bytes`` intermediate — do not
    reuse ``packed`` after this call.

    Args:
        packed: Contiguous integer tensor of packed nibbles, one byte per
            nibble pair (presumably ``torch.uint8`` — confirm against the
            packing code).
        scales: Per-block scale factors; one element per block.
        orig_shape: Shape of the returned dequantized tensor; its number of
            elements must equal ``num_blocks * block_size``.
        block_size: Number of quantized values per block.
        dtype: dtype of the output tensor.

    Returns:
        Dequantized tensor of shape ``orig_shape`` and dtype ``dtype``.
    """
    # Odd block_size pads the final high nibble, hence the ceil division.
    num_bytes = (block_size + 1) // 2
    num_blocks = packed.numel() // num_bytes

    pr = packed.view(num_blocks, num_bytes)
    out = torch.empty((num_blocks, block_size), device=packed.device, dtype=dtype)
    per_block_scale = scales.view(-1, 1)  # broadcast one scale across each block

    # Low nibbles -> even output positions. Sign-extend 4-bit two's
    # complement branch-free: raw values >= 8 encode (v - 16). torch.where
    # yields identical values to the original masked in-place subtraction
    # without a data-dependent advanced-indexing kernel. `lower` already has
    # exactly ceil(block_size / 2) columns, matching the even slots.
    lower = (pr & 0xF).to(torch.int8)
    lower = torch.where(lower >= 8, lower - 16, lower)
    out[:, 0:block_size:2] = lower.to(dtype) * per_block_scale

    # High nibbles -> odd output positions. The in-place shift reuses
    # `packed`'s storage (the documented mutation). For odd block_size the
    # last high nibble is padding and is dropped by the `[:, :hi_count]`
    # slice.
    pr >>= 4
    upper = (pr & 0xF).to(torch.int8)
    upper = torch.where(upper >= 8, upper - 16, upper)
    hi_count = block_size // 2
    out[:, 1:block_size:2] = upper[:, :hi_count].to(dtype) * per_block_scale

    return out.view(orig_shape)
|
|