import torch


def unpack_int4(packed: torch.Tensor, original_length: int) -> torch.Tensor:
    """
    Unpack a tensor of uint8 packed bytes (two 4-bit values per byte) into a
    1D tensor of int8 values, vectorized over the entire input.
    """
    lower = packed & 0xF
    upper = (packed >> 4) & 0xF
    # Interleave lower and upper nibbles, then trim any padding nibble left
    # over from an odd original_length.
    nibbles = torch.stack([lower, upper], dim=-1).view(-1)[:original_length]
    # Sign-extend: nibbles in [8, 15] encode negative int4 values.
    nibbles = nibbles.to(torch.int8)
    nibbles[nibbles >= 8] -= 16
    return nibbles


def dequantize_tensor(
    packed: torch.Tensor,
    scales: torch.Tensor,
    orig_shape: torch.Size,
    block_size: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Dequantize a packed int4 tensor (with per-block scales) back to the
    requested dtype, using vectorized operations to avoid Python loops.
    """
    num_bytes_per_block = (block_size + 1) // 2  # packed bytes per block
    num_blocks_total = packed.numel() // num_bytes_per_block
    # Reshape to (num_blocks_total, num_bytes_per_block).
    packed_rows = packed.view(num_blocks_total, num_bytes_per_block)
    # Vectorized unpacking: compute lower and upper nibbles for all rows at once.
    lower = packed_rows & 0xF
    upper = (packed_rows >> 4) & 0xF
    # Create a new dimension for the two nibbles, then flatten per block.
    nibbles = torch.stack([lower, upper], dim=2).view(num_blocks_total, -1)
    # Slice to exactly block_size values per block (drops the padding nibble
    # when block_size is odd), then sign-extend to int8.
    quantized_flat = nibbles[:, :block_size].to(torch.int8)
    quantized_flat[quantized_flat >= 8] -= 16
    # Reshape to the original block structure.
    last_dim = orig_shape[-1]
    num_blocks = last_dim // block_size
    new_shape = orig_shape[:-1] + (num_blocks, block_size)
    quantized = quantized_flat.view(new_shape)
    # Apply per-block scales and restore the original shape.
    dequantized = quantized.to(torch.float32) * scales.unsqueeze(-1)
    dequantized = dequantized.view(orig_shape)
    return dequantized.to(dtype)
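

# ---------------------------------------------------------------------------
# The packing side is not shown above, so here is a minimal sketch of a
# matching quantize-and-pack step, useful for round-trip testing of
# dequantize_tensor. `quantize_tensor` is a hypothetical helper, not part of
# the original code: it assumes symmetric per-block absmax scaling
# (scale = max|x| / 7, mapping values into the int4 range [-8, 7]) and an
# even block_size so that values pack cleanly two per byte.
# ---------------------------------------------------------------------------
def quantize_tensor(
    tensor: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Hypothetical inverse of dequantize_tensor: symmetric absmax int4
    quantization with per-block scales, packed two values per uint8 byte
    (lower nibble first, matching the unpacking order above).
    """
    orig_shape = tensor.shape
    assert orig_shape[-1] % block_size == 0, "last dim must be divisible by block_size"
    assert block_size % 2 == 0, "this sketch assumes an even block_size"
    num_blocks = orig_shape[-1] // block_size
    blocks = tensor.to(torch.float32).reshape(-1, block_size)
    # Per-block absmax scale; clamp to avoid division by zero on all-zero blocks.
    scales = (blocks.abs().amax(dim=-1, keepdim=True) / 7.0).clamp(min=1e-12)
    q = torch.clamp(torch.round(blocks / scales), -8, 7).to(torch.int8)
    # Two's-complement nibbles: masking with 0xF maps [-8, 7] onto [0, 15].
    nibble_pairs = (q & 0xF).to(torch.uint8).reshape(-1, 2)
    # First element of each pair goes in the lower nibble, second in the upper.
    packed = nibble_pairs[:, 0] | (nibble_pairs[:, 1] << 4)
    return packed, scales.reshape(orig_shape[:-1] + (num_blocks,))


if __name__ == "__main__":
    # Round-trip smoke test (hypothetical usage): per-element error should
    # stay within about half a quantization step (~scale / 2), plus the
    # rounding introduced by the bfloat16 output dtype.
    torch.manual_seed(0)
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    packed, scales = quantize_tensor(x, block_size=32)
    x_hat = dequantize_tensor(packed, scales, x.shape, block_size=32, dtype=torch.bfloat16)
    print("max abs error:", (x.float() - x_hat.float()).abs().max().item())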