Spaces:
Paused
Paused
| # SPDX-License-Identifier: Apache-2.0 | |
| # Copyright (c) ONNX Project Contributors | |
| import unittest | |
| import onnx | |
| from onnx import checker, utils | |
class TestFunction(unittest.TestCase):
    """Tests for how ``utils.Extractor`` carries model-local functions over."""

    def _verify_function_set(
        self,
        extracted_model: onnx.ModelProto,
        function_set: "set[str]",
        func_domain: str,
    ) -> None:
        """Assert *extracted_model* is valid and holds exactly the expected functions.

        Args:
            extracted_model: Model returned by ``Extractor.extract_model``.
            function_set: Names of the local functions the extracted model must
                contain. The model must contain no other functions.
            func_domain: Domain that every expected function is registered under.
        """
        checker.check_model(extracted_model)
        # Comparing (domain, name) pairs as multisets checks count and
        # membership in a single assertion, and on failure the message lists
        # precisely which functions are missing or unexpected.
        actual = [(f.domain, f.name) for f in extracted_model.functions]
        expected = [(func_domain, name) for name in function_set]
        self.assertCountEqual(actual, expected)

    def test_extract_model_with_local_function(self) -> None:
        """Extracted sub-models keep exactly the local functions they need.

        Builds a model whose graph is (node -> output):

            func_add(i0, i1)                   -> t0
            Add(i1, i2)                        -> t2
            Add(t0, t2)                        -> o_func_add
            func_identity(i1)                  -> t1
            Identity(i1)                       -> t3
            func_nested_identity_add(t0, t1)   -> o_all_func0
            func_nested_identity_add(t3, t2)   -> o_all_func1
            Add(t3, t2)                        -> o_no_func

        where ``func_nested_identity_add(a, b)`` is itself defined in terms of
        other local functions:

            a1 = func_identity(a)
            b1 = func_identity(b)
            c  = func_add(a1, b1)

        It then extracts sub-models for several output combinations and checks
        that each keeps only the local functions it depends on, including
        functions reached transitively through a nested function.
        """
        func_domain = "local"
        func_opset_imports = [onnx.helper.make_opsetid("", 14)]
        # The nested function calls other functions of the local domain, so it
        # must import that domain as well as the default one.
        func_nested_opset_imports = [
            onnx.helper.make_opsetid("", 14),
            onnx.helper.make_opsetid(func_domain, 1),
        ]

        def local_call(name, inputs, outputs):
            """Build a node that invokes the local function *name*."""
            return onnx.helper.make_node(name, inputs, outputs, domain=func_domain)

        # func_add: c = a + b
        func_add_name = "func_add"
        func_add = onnx.helper.make_function(
            func_domain,
            func_add_name,
            ["a", "b"],
            ["c"],
            [onnx.helper.make_node("Add", ["a", "b"], ["c"])],
            func_opset_imports,
        )

        # func_identity: b = a
        func_identity_name = "func_identity"
        func_identity = onnx.helper.make_function(
            func_domain,
            func_identity_name,
            ["a"],
            ["b"],
            [onnx.helper.make_node("Identity", ["a"], ["b"])],
            func_opset_imports,
        )

        # func_nested_identity_add: c = func_add(func_identity(a), func_identity(b))
        func_nested_identity_add_name = "func_nested_identity_add"
        func_nested_identity_add = onnx.helper.make_function(
            func_domain,
            func_nested_identity_add_name,
            ["a", "b"],
            ["c"],
            [
                local_call(func_identity_name, ["a"], ["a1"]),
                local_call(func_identity_name, ["b"], ["b1"]),
                local_call(func_add_name, ["a1", "b1"], ["c"]),
            ],
            func_nested_opset_imports,
        )

        # Nodes are listed in topological order.
        nodes = [
            local_call(func_add_name, ["i0", "i1"], ["t0"]),
            onnx.helper.make_node("Add", ["i1", "i2"], ["t2"]),
            onnx.helper.make_node("Add", ["t0", "t2"], ["o_func_add"]),
            local_call(func_identity_name, ["i1"], ["t1"]),
            onnx.helper.make_node("Identity", ["i1"], ["t3"]),
            local_call(func_nested_identity_add_name, ["t0", "t1"], ["o_all_func0"]),
            local_call(func_nested_identity_add_name, ["t3", "t2"], ["o_all_func1"]),
            onnx.helper.make_node("Add", ["t3", "t2"], ["o_no_func"]),
        ]

        # Every graph input and output is a 1-D uint8 tensor of length 5.
        tensor_type_proto = onnx.helper.make_tensor_type_proto(
            elem_type=onnx.TensorProto.UINT8, shape=[5]
        )

        def value_info(name):
            """Declare graph value *name* with the shared tensor type."""
            return onnx.helper.make_value_info(name=name, type_proto=tensor_type_proto)

        graph = onnx.helper.make_graph(
            nodes,
            "graph_with_imbedded_functions",
            [value_info(name) for name in ("i0", "i1", "i2")],
            [
                value_info(name)
                for name in ("o_no_func", "o_func_add", "o_all_func0", "o_all_func1")
            ],
        )
        model = onnx.helper.make_model(
            graph,
            ir_version=8,
            opset_imports=[
                onnx.helper.make_opsetid("", 14),
                onnx.helper.make_opsetid(func_domain, 1),
            ],
            producer_name="test_extract_model_with_local_function",
            functions=[func_identity, func_add, func_nested_identity_add],
        )
        checker.check_model(model)

        all_functions = {
            func_add_name,
            func_identity_name,
            func_nested_identity_add_name,
        }
        cases = [
            # Built-in ops only: no local function may be kept.
            (["o_no_func"], set()),
            # Depends on func_add through t0.
            (["o_func_add"], {func_add_name}),
            # The nested function transitively needs func_identity and func_add.
            (["o_all_func0"], all_functions),
            # Same, even though its direct inputs come from built-in ops.
            (["o_all_func1"], all_functions),
            # Requesting every output keeps the union of the above.
            (
                ["o_no_func", "o_func_add", "o_all_func0", "o_all_func1"],
                all_functions,
            ),
        ]
        for outputs, expected_functions in cases:
            # subTest reports every failing combination instead of stopping at
            # the first one.
            with self.subTest(outputs=outputs):
                extracted = utils.Extractor(model).extract_model(
                    ["i0", "i1", "i2"], outputs
                )
                self._verify_function_set(extracted, expected_functions, func_domain)
# Allow running this module directly, in addition to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()