vikhyatk's picture
Upload HfMoondream
9b4ed9c verified
raw
history blame
1.95 kB
import torch
def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
    """
    Decode a uint8 tensor holding two 4-bit values per byte into a flat int8 tensor.

    The low nibble of each byte comes first, then the high nibble; the
    interleaved stream is truncated to ``original_length`` entries (dropping any
    padding nibble) and mapped from unsigned [0, 15] to signed int4 [-8, 7].
    """
    low_nibbles = packed & 0xF
    high_nibbles = (packed >> 4) & 0xF
    # Pair each low nibble with its high nibble, flatten, drop trailing padding.
    values = torch.stack((low_nibbles, high_nibbles), dim=-1).reshape(-1)
    values = values[:original_length].to(torch.int8)
    # Two's-complement sign extension for 4-bit values.
    return torch.where(values >= 8, values - 16, values)
def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype,
):
    """
    Reverse block-wise int4 quantization, fully vectorized (no Python loops).

    Args:
        packed: uint8 tensor storing two 4-bit values per byte, one padded
            run of ``ceil(block_size / 2)`` bytes per block.
        scales: per-block scale factors; broadcast against the unpacked
            ``(..., num_blocks, block_size)`` layout via ``unsqueeze(-1)``.
        orig_shape: shape of the tensor before quantization; its last
            dimension is assumed to be a multiple of ``block_size``.
        block_size: number of quantized values covered by each scale.
        dtype: dtype of the returned tensor.

    Returns:
        The dequantized tensor with shape ``orig_shape`` and dtype ``dtype``.
    """
    bytes_per_block = (block_size + 1) // 2  # two nibbles per byte, padded if odd
    total_blocks = packed.numel() // bytes_per_block
    rows = packed.view(total_blocks, bytes_per_block)
    # Split every byte into its nibbles, low nibble first.
    low = rows & 0xF
    high = (rows >> 4) & 0xF
    interleaved = torch.stack((low, high), dim=2).reshape(total_blocks, -1)
    # Trim any padding nibble, then sign-extend 4-bit values into int8.
    quantized = interleaved[:, :block_size].to(torch.int8)
    quantized = torch.where(quantized >= 8, quantized - 16, quantized)
    # Restore the (..., num_blocks, block_size) layout the scales expect.
    blocks_per_row = orig_shape[-1] // block_size
    blocked = quantized.view(orig_shape[:-1] + (blocks_per_row, block_size))
    # Apply per-block scales in float32, then collapse back to orig_shape.
    dequantized = blocked.to(torch.float32) * scales.unsqueeze(-1)
    return dequantized.view(orig_shape).to(dtype)