Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors | |
# | |
# SPDX-License-Identifier: Apache-2.0 | |
from __future__ import annotations | |
# Public API of the package: the names exported by ``from onnx import *``.
# Entries are grouped by kind; every name listed here is bound in this module.
__all__ = [
    # Constants
    "ONNX_ML",
    "IR_VERSION",
    "IR_VERSION_2017_10_10",
    "IR_VERSION_2017_10_30",
    "IR_VERSION_2017_11_3",
    "IR_VERSION_2019_1_22",
    "IR_VERSION_2019_3_18",
    "IR_VERSION_2019_9_19",
    "IR_VERSION_2020_5_8",
    "IR_VERSION_2021_7_30",
    "IR_VERSION_2023_5_5",
    "EXPERIMENTAL",
    "STABLE",
    # Modules
    "checker",
    "compose",
    "defs",
    "gen_proto",
    "helper",
    "hub",
    "mapping",
    "numpy_helper",
    "parser",
    "printer",
    "shape_inference",
    "utils",
    "version_converter",
    # Proto classes
    "AttributeProto",
    "FunctionProto",
    "GraphProto",
    "MapProto",
    "ModelProto",
    "NodeProto",
    "OperatorProto",
    "OperatorSetIdProto",
    "OperatorSetProto",
    "OperatorStatus",
    "OptionalProto",
    "SequenceProto",
    "SparseTensorProto",
    "StringStringEntryProto",
    "TensorAnnotation",
    "TensorProto",
    "TensorShapeProto",
    "TrainingInfoProto",
    "TypeProto",
    "ValueInfoProto",
    "Version",
    # Utility functions
    "convert_model_to_external_data",
    "load_external_data_for_model",
    "load_model_from_string",
    "load_model",
    "load_tensor_from_string",
    "load_tensor",
    "save_model",
    "save_tensor",
    "write_external_data_tensors",
    # Backward-compatible aliases (assigned at the bottom of this module)
    "load",
    "load_from_string",
    "save",
]
# isort:skip_file | |
import os | |
import typing | |
from typing import IO, Literal, Union | |
from onnx import serialization | |
from onnx.onnx_cpp2py_export import ONNX_ML | |
from onnx.external_data_helper import ( | |
load_external_data_for_model, | |
write_external_data_tensors, | |
convert_model_to_external_data, | |
) | |
from onnx.onnx_pb import ( | |
AttributeProto, | |
EXPERIMENTAL, | |
FunctionProto, | |
GraphProto, | |
IR_VERSION, | |
IR_VERSION_2017_10_10, | |
IR_VERSION_2017_10_30, | |
IR_VERSION_2017_11_3, | |
IR_VERSION_2019_1_22, | |
IR_VERSION_2019_3_18, | |
IR_VERSION_2019_9_19, | |
IR_VERSION_2020_5_8, | |
IR_VERSION_2021_7_30, | |
IR_VERSION_2023_5_5, | |
ModelProto, | |
NodeProto, | |
OperatorSetIdProto, | |
OperatorStatus, | |
STABLE, | |
SparseTensorProto, | |
StringStringEntryProto, | |
TensorAnnotation, | |
TensorProto, | |
TensorShapeProto, | |
TrainingInfoProto, | |
TypeProto, | |
ValueInfoProto, | |
Version, | |
) | |
from onnx.onnx_operators_pb import OperatorProto, OperatorSetProto | |
from onnx.onnx_data_pb import MapProto, OptionalProto, SequenceProto | |
from onnx.version import version as __version__ | |
# Import common subpackages so they're available when you 'import onnx' | |
from onnx import ( | |
checker, | |
compose, | |
defs, | |
gen_proto, | |
helper, | |
hub, | |
mapping, | |
numpy_helper, | |
parser, | |
printer, | |
shape_inference, | |
utils, | |
version_converter, | |
) | |
# Supported model formats that can be loaded from and saved to | |
# The literals are formats with built-in support. But we also allow users to | |
# register their own formats. So we allow str as well. | |
_SupportedFormat = Union[Literal["protobuf", "textproto"], str] | |
# Default serialization format | |
_DEFAULT_FORMAT = "protobuf" | |
def _load_bytes(f: IO[bytes] | str | os.PathLike) -> bytes: | |
if hasattr(f, "read") and callable(typing.cast(IO[bytes], f).read): | |
content = typing.cast(IO[bytes], f).read() | |
else: | |
f = typing.cast(Union[str, os.PathLike], f) | |
with open(f, "rb") as readable: | |
content = readable.read() | |
return content | |
def _save_bytes(content: bytes, f: IO[bytes] | str | os.PathLike) -> None: | |
if hasattr(f, "write") and callable(typing.cast(IO[bytes], f).write): | |
typing.cast(IO[bytes], f).write(content) | |
else: | |
f = typing.cast(Union[str, os.PathLike], f) | |
with open(f, "wb") as writable: | |
writable.write(content) | |
def _get_file_path(f: IO[bytes] | str | os.PathLike | None) -> str | None: | |
if isinstance(f, (str, os.PathLike)): | |
return os.path.abspath(f) | |
if hasattr(f, "name"): | |
assert f is not None | |
return os.path.abspath(f.name) | |
return None | |
def _get_serializer(
    fmt: _SupportedFormat | None, f: str | os.PathLike | IO[bytes] | None = None
) -> serialization.ProtoSerializer:
    """Look up the serializer for a format/target in the serialization registry.

    Args:
        fmt: The serialization format. When not None it is used as is.
        f: Optional path or file-like object. Only consulted when ``fmt`` is
            None, in which case the format is resolved from the extension of
            the file path obtained from ``f``.

    Returns:
        The serializer registered for the resolved format. The default format
        ('protobuf') is used when no format could be resolved.
    """
    if fmt is None:
        path = _get_file_path(f)
        if path is not None:
            extension = os.path.splitext(path)[1]
            fmt = serialization.registry.get_format_from_file_extension(extension)
        # No path available, or the extension did not resolve to a format.
        if not fmt:
            fmt = _DEFAULT_FORMAT
    return serialization.registry.get(fmt)
def load_model(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    load_external_data: bool = True,
) -> ModelProto:
    """Load a serialized ModelProto into memory.

    Args:
        f: A file-like object (has a "read" function) or a string/PathLike
            containing a file name.
        format: The serialization format. When it is not specified, it is
            inferred from the file extension when ``f`` is a path. If not
            specified _and_ ``f`` is not a path, 'protobuf' is used. The
            encoding is assumed to be "utf-8" when the format is a text format.
        load_external_data: Whether to load the external data. Set to True if
            the data is under the same directory as the model. If not, users
            need to call :func:`load_external_data_for_model` with the
            directory to load external data from.

    Returns:
        The loaded in-memory ModelProto.
    """
    serializer = _get_serializer(format, f)
    model = serializer.deserialize_proto(_load_bytes(f), ModelProto())
    if not load_external_data:
        return model

    # External data is looked up in the model file's own directory, which is
    # only known when a file path can be obtained from ``f``.
    model_path = _get_file_path(f)
    if model_path:
        load_external_data_for_model(model, os.path.dirname(model_path))
    return model
def load_tensor(
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> TensorProto:
    """Load a serialized TensorProto into memory.

    Args:
        f: A file-like object (has a "read" function) or a string/PathLike
            containing a file name.
        format: The serialization format. When it is not specified, it is
            inferred from the file extension when ``f`` is a path. If not
            specified _and_ ``f`` is not a path, 'protobuf' is used. The
            encoding is assumed to be "utf-8" when the format is a text format.

    Returns:
        The loaded in-memory TensorProto.
    """
    serializer = _get_serializer(format, f)
    payload = _load_bytes(f)
    return serializer.deserialize_proto(payload, TensorProto())
def load_model_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> ModelProto:
    """Load a ModelProto from its serialized in-memory representation.

    Args:
        s: The serialized ModelProto.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file name here, so nothing is inferred from a file
            extension. The encoding is assumed to be "utf-8" when the format
            is a text format.

    Returns:
        The loaded in-memory ModelProto.
    """
    return _get_serializer(format).deserialize_proto(s, ModelProto())
def load_tensor_from_string(
    s: bytes | str,
    format: _SupportedFormat = _DEFAULT_FORMAT,  # noqa: A002
) -> TensorProto:
    """Load a TensorProto from its serialized in-memory representation.

    Args:
        s: The serialized TensorProto. It is handed unchanged to the
            serializer selected by ``format``, exactly as in
            :func:`load_model_from_string`, so it is annotated the same way.
        format: The serialization format of ``s``. Defaults to 'protobuf'.
            There is no file name here, so nothing is inferred from a file
            extension. The encoding is assumed to be "utf-8" when the format
            is a text format.

    Returns:
        The loaded in-memory TensorProto.
    """
    return _get_serializer(format).deserialize_proto(s, TensorProto())
def save_model(
    proto: ModelProto | bytes,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
    *,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Save a ModelProto to ``f``, optionally moving tensor data to external files.

    Args:
        proto: An in-memory ModelProto. ``bytes`` are also accepted and are
            first deserialized as a ModelProto using the default format.
        f: A file-like object (has a "write" function) or a string containing
            a file name or a PathLike object.
        format: The serialization format. When it is not specified, it is
            inferred from the file extension when ``f`` is a path. If not
            specified _and_ ``f`` is not a path, 'protobuf' is used. The
            encoding is assumed to be "utf-8" when the format is a text format.
        save_as_external_data: If true, save tensors to external file(s).
        all_tensors_to_one_file: Effective only if save_as_external_data is
            true. If true, save all tensors to one external file specified by
            location. If false, save each tensor to a file named with the
            tensor name.
        location: Effective only if save_as_external_data is true. Specifies
            the external file that all tensors are saved to. The path is
            relative to the model path. If not specified, the model name is
            used.
        size_threshold: Effective only if save_as_external_data is true.
            Threshold for the size of data: a tensor is converted to external
            data only when its data is >= size_threshold. To convert every
            tensor with raw data to external data set size_threshold=0.
        convert_attribute: Effective only if save_as_external_data is true.
            If true, convert all tensors to external data. If false, convert
            only non-attribute tensors to external data.
    """
    if isinstance(proto, bytes):
        default_serializer = _get_serializer(_DEFAULT_FORMAT)
        proto = default_serializer.deserialize_proto(proto, ModelProto())

    if save_as_external_data:
        convert_model_to_external_data(
            proto, all_tensors_to_one_file, location, size_threshold, convert_attribute
        )

    # External tensor files are written into the model file's directory,
    # which is only known when a file path can be obtained from ``f``.
    target_path = _get_file_path(f)
    if target_path is not None:
        proto = write_external_data_tensors(proto, os.path.dirname(target_path))

    payload = _get_serializer(format, target_path).serialize_proto(proto)
    _save_bytes(payload, f)
def save_tensor(
    proto: TensorProto,
    f: IO[bytes] | str | os.PathLike,
    format: _SupportedFormat | None = None,  # noqa: A002
) -> None:
    """Save a TensorProto to ``f``.

    Args:
        proto: An in-memory TensorProto.
        f: A file-like object (has a "write" function) or a string containing
            a file name or a PathLike object.
        format: The serialization format. When it is not specified, it is
            inferred from the file extension when ``f`` is a path. If not
            specified _and_ ``f`` is not a path, 'protobuf' is used. The
            encoding is assumed to be "utf-8" when the format is a text format.
    """
    serializer = _get_serializer(format, f)
    _save_bytes(serializer.serialize_proto(proto), f)
# Short aliases of the model load/save functions, kept for backward
# compatibility.
save = save_model
load_from_string = load_model_from_string
load = load_model