Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): | |
| mantissa_scaled = torch.where( | |
| normal_mask, | |
| (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), | |
| (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) | |
| ) | |
| mantissa_scaled += torch.rand(mantissa_scaled.size(), dtype=mantissa_scaled.dtype, layout=mantissa_scaled.layout, device=mantissa_scaled.device, generator=generator) | |
| return mantissa_scaled.floor() / (2**MANTISSA_BITS) | |
| #Not 100% sure about this | |
| def manual_stochastic_round_to_float8(x, dtype, generator=None): | |
| if dtype == torch.float8_e4m3fn: | |
| EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7 | |
| elif dtype == torch.float8_e5m2: | |
| EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15 | |
| else: | |
| raise ValueError("Unsupported dtype") | |
| x = x.half() | |
| sign = torch.sign(x) | |
| abs_x = x.abs() | |
| sign = torch.where(abs_x == 0, 0, sign) | |
| # Combine exponent calculation and clamping | |
| exponent = torch.clamp( | |
| torch.floor(torch.log2(abs_x)) + EXPONENT_BIAS, | |
| 0, 2**EXPONENT_BITS - 1 | |
| ) | |
| # Combine mantissa calculation and rounding | |
| normal_mask = ~(exponent == 0) | |
| abs_x[:] = calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=generator) | |
| sign *= torch.where( | |
| normal_mask, | |
| (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), | |
| (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x | |
| ) | |
| inf = torch.finfo(dtype) | |
| torch.clamp(sign, min=inf.min, max=inf.max, out=sign) | |
| return sign | |
| def stochastic_rounding(value, dtype, seed=0): | |
| if dtype == torch.float32: | |
| return value.to(dtype=torch.float32) | |
| if dtype == torch.float16: | |
| return value.to(dtype=torch.float16) | |
| if dtype == torch.bfloat16: | |
| return value.to(dtype=torch.bfloat16) | |
| if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: | |
| #generator = torch.Generator(device='cuda' if torch.cuda.is_available() else 'cpu') | |
| torch.manual_seed(seed) | |
| if(torch.cuda.is_available()): | |
| torch.cuda.manual_seed(seed) | |
| output = torch.empty_like(value, dtype=dtype) | |
| num_slices = max(1, (value.numel() / (4096 * 4096))) | |
| slice_size = max(1, round(value.shape[0] / num_slices)) | |
| with torch.no_grad(): | |
| for i in range(0, value.shape[0], slice_size): | |
| output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype)) | |
| return output | |
| return value.to(dtype=dtype) | |