File size: 1,345 Bytes
9b4ed9c f089fb8 9b4ed9c f089fb8 9b4ed9c f089fb8 9b4ed9c f089fb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
    """Expand a byte tensor holding two 4-bit values per byte into int8 nibbles.

    Each byte contributes two values along the last dimension: the low
    nibble first, then the high nibble. The expanded last dimension is
    truncated to ``original_length`` (this drops a padding nibble when the
    pre-packing count was odd).

    Args:
        packed: Tensor of packed bytes; all leading dimensions are preserved.
        original_length: Number of 4-bit values the last dimension held
            before packing (at most ``2 * packed.shape[-1]``).

    Returns:
        Tensor of dtype ``torch.int8`` with shape
        ``(*packed.shape[:-1], original_length)``; values lie in [0, 15].
    """
    *lead_dims, n_bytes = packed.shape
    low_nibbles = packed & 0xF
    high_nibbles = (packed >> 4) & 0xF
    # Interleave (low, high) per byte, then fold the pair axis into the
    # last dimension: byte i yields output positions 2*i and 2*i + 1.
    interleaved = torch.stack((low_nibbles, high_nibbles), dim=-1)
    expanded = interleaved.reshape(*lead_dims, n_bytes * 2)
    return expanded[..., :original_length].to(torch.int8)
def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype = torch.bfloat16,
):
    """Reconstruct a floating-point tensor from block-wise int4 quantized data.

    Applies per-block affine dequantization,
    ``value = (nibble - zero_point) * scale``, then strips any padding that
    was added to round the last dimension up to a whole number of blocks.

    Args:
        packed: Byte tensor of shape ``[out_features, num_blocks, block_size // 2]``
            holding two 4-bit values per byte.
        scales: Per-block scales, shape ``[out_features, num_blocks]``.
        zero_points: Per-block zero points, shape ``[out_features, num_blocks]``.
        orig_shape: Shape of the tensor before quantization; assumed 2-D here
            (``orig_shape[1]`` is read as the unpadded feature count) — TODO
            confirm against callers.
        block_size: Number of 4-bit values per quantization block.
        dtype: Output dtype (defaults to ``torch.bfloat16``).

    Returns:
        Dequantized tensor of shape ``orig_shape`` and dtype ``dtype``.
    """
    out_features, num_blocks, _ = packed.shape
    # Nibbles come back as [out_features, num_blocks, block_size].
    nibbles = unpack_int4(packed, block_size)
    # unsqueeze(2) broadcasts the per-block parameters across the block axis.
    affine = (nibbles.float() - zero_points.unsqueeze(2)) * scales.unsqueeze(2)
    # Collapse blocks, drop padding columns, and restore the original shape.
    flat = affine.reshape(out_features, num_blocks * block_size)
    trimmed = flat[:, : orig_shape[1]]
    return trimmed.reshape(orig_shape).to(dtype)
|