Spaces:
Paused
Paused
| # Copyright (c) ONNX Project Contributors | |
| # | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import warnings | |
| from typing import Any, Dict, NamedTuple, Union, cast | |
| import numpy as np | |
| from onnx import OptionalProto, SequenceProto, TensorProto | |
# One row of TENSOR_TYPE_MAP, describing a single TensorProto element type:
#   np_dtype      - numpy dtype used to represent the values
#   storage_dtype - TensorProto element type (as an int) used for storage
#   name          - printable name, e.g. "TensorProto.FLOAT"
TensorDtypeMap = NamedTuple(
    "TensorDtypeMap",
    [
        ("np_dtype", np.dtype),
        ("storage_dtype", int),
        ("name", str),
    ],
)
# Master table keyed by TensorProto element type (as an int).
# tensor_dtype: (numpy type, storage type, string name)
TENSOR_TYPE_MAP = {
    int(getattr(TensorProto, type_name)): TensorDtypeMap(
        np.dtype(numpy_name),
        int(getattr(TensorProto, storage_name)),
        f"TensorProto.{type_name}",
    )
    # Each row: (TensorProto member name, numpy dtype name,
    #            TensorProto member name of the storage type).
    for type_name, numpy_name, storage_name in (
        ("FLOAT", "float32", "FLOAT"),
        ("UINT8", "uint8", "INT32"),
        ("INT8", "int8", "INT32"),
        ("UINT16", "uint16", "INT32"),
        ("INT16", "int16", "INT32"),
        ("INT32", "int32", "INT32"),
        ("INT64", "int64", "INT64"),
        ("BOOL", "bool", "INT32"),
        ("FLOAT16", "float16", "UINT16"),
        # Native numpy does not support bfloat16, so float32 is used for now.
        ("BFLOAT16", "float32", "UINT16"),
        ("DOUBLE", "float64", "DOUBLE"),
        ("COMPLEX64", "complex64", "FLOAT"),
        ("COMPLEX128", "complex128", "DOUBLE"),
        ("UINT32", "uint32", "UINT32"),
        ("UINT64", "uint64", "UINT64"),
        ("STRING", "object", "STRING"),
        # Native numpy does not support float8 types, so float32 is used
        # for these types for now.
        ("FLOAT8E4M3FN", "float32", "UINT8"),
        ("FLOAT8E4M3FNUZ", "float32", "UINT8"),
        ("FLOAT8E5M2", "float32", "UINT8"),
        ("FLOAT8E5M2FNUZ", "float32", "UINT8"),
        # Native numpy does not support uint4/int4, so uint8/int8 are used
        # for these types for now.
        ("UINT4", "uint8", "INT32"),
        ("INT4", "int8", "INT32"),
    )
}
class DeprecatedWarningDict(dict):  # type: ignore
    """Dictionary that emits a ``DeprecationWarning`` on every ``[]`` lookup.

    Only ``__getitem__`` is instrumented; the other accessors are not
    overridden in this class.

    Args:
        dictionary: Initial contents of the mapping.
        original_function: Name of the deprecated object, reported in the
            warning as ``mapping.<original_function>``.
        future_function: Name of the replacement, reported in the warning as
            ``helper.<future_function>``. When empty, the warning suggests
            an if-else statement instead.
    """

    def __init__(
        self,
        dictionary: Dict[int, Union[int, str, np.dtype]],
        original_function: str,
        future_function: str = "",
    ) -> None:
        super().__init__(dictionary)
        self._origin_function = original_function
        self._future_function = future_function

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a DeprecatedWarningDict with the same names.

        Equality is decided by the original/future function names only; the
        stored items are not compared.
        """
        if not isinstance(other, DeprecatedWarningDict):
            return False
        return (
            self._origin_function == other._origin_function
            and self._future_function == other._future_function
        )

    def __ne__(self, other: object) -> bool:
        """Return the negation of ``__eq__``.

        Without this override ``!=`` would fall back to ``dict.__ne__`` and
        compare contents, which can contradict the name-based ``==`` above.
        """
        return not self.__eq__(other)

    def __getitem__(self, key: Union[int, str, np.dtype]) -> Any:
        """Warn that the mapping is deprecated, then return the value for *key*.

        Raises:
            KeyError: If *key* is not present (the warning is still emitted).
        """
        if self._future_function:
            advice = f"please use `helper.{self._future_function}` instead."
        else:
            advice = (
                "please simply use if-else statement to get the corresponding value."
            )
        # stacklevel=2 attributes the warning to the code performing the lookup.
        warnings.warn(
            f"`mapping.{self._origin_function}` is now deprecated and will be removed in a future release. "
            f"To silence this warning, {advice}",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__getitem__(key)
# Deprecated view of TENSOR_TYPE_MAP: tensor type -> numpy dtype.
# This map is used for converting TensorProto values into numpy arrays.
TENSOR_TYPE_TO_NP_TYPE = DeprecatedWarningDict(
    {
        tensor_type: TENSOR_TYPE_MAP[tensor_type].np_dtype
        for tensor_type in TENSOR_TYPE_MAP
    },
    original_function="TENSOR_TYPE_TO_NP_TYPE",
    future_function="tensor_dtype_to_np_dtype",
)
# Deprecated view of TENSOR_TYPE_MAP: tensor type -> storage tensor type.
# This is only used to get keys into STORAGE_TENSOR_TYPE_TO_FIELD.
# TODO(https://github.com/onnx/onnx/issues/4554): Move these variables into _mapping.py
TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE = DeprecatedWarningDict(
    {
        tensor_type: TENSOR_TYPE_MAP[tensor_type].storage_dtype
        for tensor_type in TENSOR_TYPE_MAP
    },
    original_function="TENSOR_TYPE_TO_STORAGE_TENSOR_TYPE",
    future_function="tensor_dtype_to_storage_tensor_dtype",
)
# NP_TYPE_TO_TENSOR_TYPE will eventually be removed; after that only
# _NP_TYPE_TO_TENSOR_TYPE remains, for internal use.
_NP_TYPE_TO_TENSOR_TYPE = {
    numpy_type: tensor_type
    # .items() is not overridden by DeprecatedWarningDict; only `[]` lookups
    # go through its warning-emitting __getitem__.
    for tensor_type, numpy_type in TENSOR_TYPE_TO_NP_TYPE.items()
    # Skip the tensor types that numpy cannot represent natively: they borrow
    # float32/uint8/int8 and must not shadow the real owner of that dtype.
    if tensor_type
    not in {
        TensorProto.BFLOAT16,
        TensorProto.FLOAT8E4M3FN,
        TensorProto.FLOAT8E4M3FNUZ,
        TensorProto.FLOAT8E5M2,
        TensorProto.FLOAT8E5M2FNUZ,
        TensorProto.UINT4,
        TensorProto.INT4,
    }
}
# Native numpy currently has no bfloat16, so TensorProto.BFLOAT16 is left out
# for now; a numpy float32 array therefore maps back to TensorProto.FLOAT only.
NP_TYPE_TO_TENSOR_TYPE = DeprecatedWarningDict(
    cast(Dict[int, Union[int, str, Any]], _NP_TYPE_TO_TENSOR_TYPE),
    original_function="NP_TYPE_TO_TENSOR_TYPE",
    future_function="np_dtype_to_tensor_dtype",
)
# STORAGE_TENSOR_TYPE_TO_FIELD will eventually be removed; after that only
# _STORAGE_TENSOR_TYPE_TO_FIELD remains, for internal use.
_STORAGE_TENSOR_TYPE_TO_FIELD = {
    int(storage_type): field_name
    # Each row: (storage tensor type, name of the data field holding it).
    for storage_type, field_name in (
        (TensorProto.FLOAT, "float_data"),
        (TensorProto.INT32, "int32_data"),
        (TensorProto.INT64, "int64_data"),
        (TensorProto.UINT8, "int32_data"),
        (TensorProto.UINT16, "int32_data"),
        (TensorProto.DOUBLE, "double_data"),
        (TensorProto.COMPLEX64, "float_data"),
        (TensorProto.COMPLEX128, "double_data"),
        (TensorProto.UINT32, "uint64_data"),
        (TensorProto.UINT64, "uint64_data"),
        (TensorProto.STRING, "string_data"),
        (TensorProto.BOOL, "int32_data"),
    )
}
# No future_function is given, so the deprecation warning for this map
# suggests an if-else statement rather than naming a helper.
STORAGE_TENSOR_TYPE_TO_FIELD = DeprecatedWarningDict(
    cast(Dict[int, Union[int, str, Any]], _STORAGE_TENSOR_TYPE_TO_FIELD),
    original_function="STORAGE_TENSOR_TYPE_TO_FIELD",
)
# Sequence element type -> name of the field holding the values.
# This map will be removed and there is no replacement for it.
STORAGE_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(element_type): field_name
        for element_type, field_name in (
            (SequenceProto.TENSOR, "tensor_values"),
            (SequenceProto.SPARSE_TENSOR, "sparse_tensor_values"),
            (SequenceProto.SEQUENCE, "sequence_values"),
            (SequenceProto.MAP, "map_values"),
            # The optional entry is keyed by the OptionalProto enum value.
            (OptionalProto.OPTIONAL, "optional_value"),
        )
    },
    original_function="STORAGE_ELEMENT_TYPE_TO_FIELD",
)
# Optional element type -> name of the field holding the value.
# This map will be removed and there is no replacement for it.
OPTIONAL_ELEMENT_TYPE_TO_FIELD = DeprecatedWarningDict(
    {
        int(element_type): field_name
        for element_type, field_name in (
            (OptionalProto.TENSOR, "tensor_value"),
            (OptionalProto.SPARSE_TENSOR, "sparse_tensor_value"),
            (OptionalProto.SEQUENCE, "sequence_value"),
            (OptionalProto.MAP, "map_value"),
            (OptionalProto.OPTIONAL, "optional_value"),
        )
    },
    original_function="OPTIONAL_ELEMENT_TYPE_TO_FIELD",
)