# Copyright (c) ONNX Project Contributors # # SPDX-License-Identifier: Apache-2.0 """Graph utilities for checking whether an ONNX proto message is legal.""" from __future__ import annotations __all__ = [ "check_attribute", "check_function", "check_graph", "check_model", "check_node", "check_sparse_tensor", "check_tensor", "check_value_info", "DEFAULT_CONTEXT", "LEXICAL_SCOPE_CONTEXT", "ValidationError", "C", "MAXIMUM_PROTOBUF", ] import os import sys from typing import Any, Callable, TypeVar from google.protobuf.message import Message import onnx.defs import onnx.onnx_cpp2py_export.checker as C # noqa: N812 import onnx.shape_inference from onnx import ( IR_VERSION, AttributeProto, FunctionProto, GraphProto, ModelProto, NodeProto, SparseTensorProto, TensorProto, ValueInfoProto, ) # Limitation of single protobuf file is 2GB MAXIMUM_PROTOBUF = 2000000000 # TODO: This thing where we reserialize the protobuf back into the # string, only to deserialize it at the call site, is really goofy. # Stop doing that. # NB: Please don't edit this context! DEFAULT_CONTEXT = C.CheckerContext() DEFAULT_CONTEXT.ir_version = IR_VERSION # TODO: Maybe ONNX-ML should also be defaulted? 
DEFAULT_CONTEXT.opset_imports = {"": onnx.defs.onnx_opset_version()}

# Shared, module-level lexical scope used as the default for the
# scope-aware checks (attributes, nodes, graphs, functions).
LEXICAL_SCOPE_CONTEXT = C.LexicalScopeContext()

# Generic callable alias kept for typing decorator-style helpers.
FuncType = TypeVar("FuncType", bound=Callable[..., Any])


def _ensure_proto_type(proto: Message, proto_type: type[Message]) -> None:
    """Raise ``TypeError`` if *proto* is not an instance of *proto_type*."""
    if not isinstance(proto, proto_type):
        raise TypeError(
            f"The proto message needs to be of type '{proto_type.__name__}'"
        )


def check_value_info(
    value_info: ValueInfoProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Check the consistency of an ONNX ``ValueInfoProto``.

    Args:
        value_info: Message to validate.
        ctx: Checker context (IR version and opset imports) to validate against.

    Raises:
        TypeError: If *value_info* is not a ``ValueInfoProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(value_info, ValueInfoProto)
    return C.check_value_info(value_info.SerializeToString(), ctx)


def check_tensor(tensor: TensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT) -> None:
    """Check the consistency of an ONNX ``TensorProto``.

    Args:
        tensor: Message to validate.
        ctx: Checker context to validate against.

    Raises:
        TypeError: If *tensor* is not a ``TensorProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(tensor, TensorProto)
    return C.check_tensor(tensor.SerializeToString(), ctx)


def check_attribute(
    attr: AttributeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Check the consistency of an ONNX ``AttributeProto``.

    Args:
        attr: Message to validate.
        ctx: Checker context to validate against.
        lexical_scope_ctx: Lexical scope (names visible from enclosing graphs).

    Raises:
        TypeError: If *attr* is not an ``AttributeProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(attr, AttributeProto)
    return C.check_attribute(attr.SerializeToString(), ctx, lexical_scope_ctx)


def check_node(
    node: NodeProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Check the consistency of an ONNX ``NodeProto``.

    Args:
        node: Message to validate.
        ctx: Checker context to validate against.
        lexical_scope_ctx: Lexical scope (names visible from enclosing graphs).

    Raises:
        TypeError: If *node* is not a ``NodeProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(node, NodeProto)
    return C.check_node(node.SerializeToString(), ctx, lexical_scope_ctx)


def check_function(
    function: FunctionProto,
    ctx: C.CheckerContext | None = None,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Check the consistency of an ONNX ``FunctionProto``.

    Args:
        function: Message to validate.
        ctx: Checker context. When ``None``, a context is derived from the
            function's own ``opset_import`` entries (minimum compatible IR
            version plus the function's opset imports).
        lexical_scope_ctx: Lexical scope (names visible from enclosing graphs).

    Raises:
        TypeError: If *function* is not a ``FunctionProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(function, FunctionProto)
    if ctx is None:
        # Build a context tailored to the function: IR version is the
        # smallest one that supports its opset imports; unknown domains
        # are tolerated (ignore_unknown=True).
        ctx = C.CheckerContext()
        ctx.ir_version = onnx.helper.find_min_ir_version_for(
            function.opset_import, ignore_unknown=True
        )
        ctx.opset_imports = {
            domain_version.domain: domain_version.version
            for domain_version in function.opset_import
        }
    C.check_function(function.SerializeToString(), ctx, lexical_scope_ctx)


def check_graph(
    graph: GraphProto,
    ctx: C.CheckerContext = DEFAULT_CONTEXT,
    lexical_scope_ctx: C.LexicalScopeContext = LEXICAL_SCOPE_CONTEXT,
) -> None:
    """Check the consistency of an ONNX ``GraphProto``.

    Args:
        graph: Message to validate.
        ctx: Checker context to validate against.
        lexical_scope_ctx: Lexical scope (names visible from enclosing graphs).

    Raises:
        TypeError: If *graph* is not a ``GraphProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(graph, GraphProto)
    return C.check_graph(graph.SerializeToString(), ctx, lexical_scope_ctx)


def check_sparse_tensor(
    sparse: SparseTensorProto, ctx: C.CheckerContext = DEFAULT_CONTEXT
) -> None:
    """Check the consistency of an ONNX ``SparseTensorProto``.

    Args:
        sparse: Message to validate.
        ctx: Checker context to validate against.

    Raises:
        TypeError: If *sparse* is not a ``SparseTensorProto``.
        ValidationError: If the message is not legal.
    """
    _ensure_proto_type(sparse, SparseTensorProto)
    C.check_sparse_tensor(sparse.SerializeToString(), ctx)


def check_model(
    model: ModelProto | str | bytes | os.PathLike,
    full_check: bool = False,
    skip_opset_compatibility_check: bool = False,
    check_custom_domain: bool = False,
) -> None:
    """Check the consistency of a model.

    An exception will be raised if the model's ir_version is not set
    properly or is higher than checker's ir_version, or if the model
    has duplicate keys in metadata_props.

    If IR version >= 3, the model must specify opset_import.
    If IR version < 3, the model cannot have any opset_import specified.

    Args:
        model: Model to check. If model is a path, the function checks model
            path first. If the model bytes size is larger than 2GB, function
            should be called using model path.
        full_check: If True, the function also runs shape inference check.
        skip_opset_compatibility_check: If True, the function skips the check for
            opset compatibility.
        check_custom_domain: If True, the function will check all domains. Otherwise
            only check built-in domains.

    Raises:
        TypeError: If *model* is not a ``ModelProto``, path, or bytes.
        ValueError: If the serialized model exceeds the 2GB protobuf limit.
        ValidationError: If the model is not legal.
    """
    # If model is a path instead of ModelProto, let the C++ checker load
    # it from disk (avoids holding the whole serialized model in memory).
    if isinstance(model, (str, os.PathLike)):
        C.check_model_path(
            os.fspath(model),
            full_check,
            skip_opset_compatibility_check,
            check_custom_domain,
        )
    else:
        protobuf_string = (
            model if isinstance(model, bytes) else model.SerializeToString()
        )
        # The 2GB protobuf limit applies to the serialized payload itself,
        # so measure len(bytes) rather than sys.getsizeof(), which also
        # counts the Python object header and overstates the size.
        # If the protobuf is larger than 2GB,
        # remind users should use the model path to check
        if len(protobuf_string) > MAXIMUM_PROTOBUF:
            raise ValueError(
                "This protobuf of onnx model is too large (>2GB). Call check_model with model path instead."
            )
        C.check_model(
            protobuf_string,
            full_check,
            skip_opset_compatibility_check,
            check_custom_domain,
        )


# Re-export the C++ checker's exception type as the module's public error.
ValidationError = C.ValidationError