Spaces:
Running
Running
| # coding=utf-8 | |
| # Copyright 2022 The Google Research Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Helper routines for quantization.""" | |
| from typing import Any | |
| import chex | |
| import jax.numpy as jnp | |
| from flax import struct | |
| # pylint:disable=no-value-for-parameter | |
@struct.dataclass
class QuantizedValue:
    """State associated with quantized value.

    Holds a (possibly) lower-precision copy of a float array together with
    the per-slice scale needed to reconstruct it. The array fields are
    regular dataclass fields; `quantized_dtype`, `extract_diagonal` and
    `shape` are declared with `pytree_node=False`.

    Build instances with `QuantizedValue.from_float_value(...)` and recover
    the float array with `to_float()`.
    """

    quantized: chex.Array
    diagonal: chex.Array  # Diagonal (if extract_diagonal is set)
    bucket_size: chex.Array
    quantized_dtype: jnp.dtype = struct.field(
        pytree_node=False
    )  # Dtype for the quantized value.
    extract_diagonal: bool = struct.field(pytree_node=False)  # In case its centered.
    shape: Any = struct.field(pytree_node=False)  # Shape of the tensor.

    @classmethod
    def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Builds a QuantizedValue from a float array.

        Args:
          fvalue: Array to quantize. An empty list is accepted and yields an
            instance whose array fields and shape are all empty lists.
          quantized_dtype: Target dtype; see `quantize` for supported values.
          extract_diagonal: If True, the diagonal of the (2D) input is stored
            separately in full precision and only the off-diagonal part is
            quantized.

        Returns:
          A new instance wrapping the quantized data.

        Raises:
          ValueError: Propagated from `quantize` for an unsupported dtype or
            an input of unsuitable rank.
        """
        # Empty placeholder: nothing to quantize.
        if isinstance(fvalue, list) and not fvalue:
            return cls([], [], [], quantized_dtype, extract_diagonal, [])
        quantized, diagonal_fvalue, bucket_size = cls.quantize(
            fvalue, quantized_dtype, extract_diagonal
        )
        return cls(
            quantized,
            diagonal_fvalue,
            bucket_size,
            quantized_dtype,
            extract_diagonal,
            list(quantized.shape),
        )

    # Quantization is from Lingvo JAX optimizers.
    # We extend it for int16 quantization of PSD matrices.
    @classmethod
    def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
        """Returns quantized value and the bucket.

        Args:
          fvalue: Float array to quantize.
          quantized_dtype: One of `jnp.float32` (returned unchanged),
            `jnp.bfloat16` (plain cast), `jnp.int8` or `jnp.int16` (scaled
            integer quantization).
          extract_diagonal: If True (integer dtypes only), `fvalue` must be
            2D; its diagonal is returned separately and zeroed out before
            quantization.

        Returns:
          A tuple `(quantized, diagonal, bucket_size)`. For the float32 and
          bfloat16 paths `diagonal` and `bucket_size` are empty lists. For
          the integer paths `bucket_size` is the scale computed along axis 0
          and `diagonal` is the extracted diagonal, or an empty list when
          `extract_diagonal` is False.

        Raises:
          ValueError: If `quantized_dtype` is not supported, if
            `extract_diagonal` is set and `fvalue` is not 2D, or if `fvalue`
            has no dimensions.
        """
        if quantized_dtype == jnp.float32:
            return fvalue, [], []
        elif quantized_dtype == jnp.bfloat16:
            return fvalue.astype(jnp.bfloat16), [], []

        float_dtype = fvalue.dtype
        if quantized_dtype == jnp.int8:
            # value -128 is not used.
            num_buckets = jnp.array(127.0, dtype=float_dtype)
        elif quantized_dtype == jnp.int16:
            # value -32768 is not used.
            num_buckets = jnp.array(32767.0, dtype=float_dtype)
        else:
            raise ValueError(f"Quantized dtype {quantized_dtype} not supported.")

        # max value is mapped to num_buckets
        if extract_diagonal and fvalue.ndim != 2:
            raise ValueError(
                f"Input array {fvalue} must be 2D to work with extract_diagonal."
            )

        diagonal_fvalue = []
        if extract_diagonal:
            diagonal_fvalue = jnp.diag(fvalue)
            # Remove the diagonal entries.
            fvalue = fvalue - jnp.diag(diagonal_fvalue)

        # TODO(rohananil): Extend this by making use of information about the blocks
        # SM3 style which will be useful for diagonal statistics
        # We first decide the scale.
        if fvalue.ndim < 1:
            raise ValueError(
                f"Input array {fvalue} must have a strictly positive number of "
                "dimensions."
            )

        # One scale per slice along axis 0, so that the largest magnitude in
        # each slice maps to +/- num_buckets.
        max_abs = jnp.max(jnp.abs(fvalue), axis=0)
        bucket_size = max_abs / num_buckets
        bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
        # To avoid divide by 0.0
        bs_nonzero = jnp.where(
            bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
        )
        ratio = fvalue / bs_nonzero
        # We use rounding to remove bias.
        quantized = jnp.round(ratio)
        return quantized.astype(quantized_dtype), diagonal_fvalue, bucket_size

    def to_float(self):
        """Returns the float value.

        Inverts `quantize`: an empty-list placeholder and float32 data are
        returned as stored, bfloat16 data is cast to float32, and integer
        data is multiplied by `bucket_size` (with the separately stored
        diagonal added back when `extract_diagonal` is set).
        """
        if isinstance(self.quantized, list) and not self.quantized:
            return self.quantized

        if self.quantized_dtype == jnp.float32:
            return self.quantized

        if self.quantized_dtype == jnp.bfloat16:
            return self.quantized.astype(jnp.float32)

        float_dtype = self.bucket_size.dtype
        # Re-add the leading axis that was reduced when the scale was computed
        # so the multiplication broadcasts over it.
        bucket_size = self.bucket_size[jnp.newaxis, Ellipsis]
        val = self.quantized.astype(float_dtype) * bucket_size
        if self.extract_diagonal:
            val += jnp.diag(self.diagonal)
        return val