# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Union

import numpy as np

from onnx import (
    AttributeProto,
    FunctionProto,
    GraphProto,
    ModelProto,
    NodeProto,
    SparseTensorProto,
    TensorProto,
)
from onnx.helper import (
    make_attribute,
    make_function,
    make_graph,
    make_model,
    make_node,
    make_tensor,
    make_tensor_value_info,
    set_model_props,
    tensor_dtype_to_np_dtype,
)
from onnx.numpy_helper import from_array


def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> List[NodeProto]:
    """Replaces a Constant node holding a large tensor (more than *threshold*
    elements) by a pair of nodes that produce a dummy constant of the same
    shape as the original tensor.
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name == "value":
            value = att.t
            new_name = f"{value.name}__SHAPE"
            dims = value.dims
            size = np.prod(dims)
            if size <= threshold:
                return [node]
            init = from_array(np.array(list(dims), dtype=np.int64), name=new_name)
            dtype = tensor_dtype_to_np_dtype(value.data_type)
            node_shape = make_node(
                "Constant",
                [],
                [new_name],
                value=init,
            )
            new_node = make_node(
                "ConstantOfShape",
                [new_name],
                node.output,
                value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
            )
            return [node_shape, new_node]
        raise NotImplementedError(
            f"Replacement of constant with attribute {att.name!r}"
        )
    return [node]
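

# A minimal sketch (added for illustration, not part of the upstream module)
# showing what ``_replace_constant`` produces: a Constant holding 1,000 floats
# becomes a small Constant holding only the shape, feeding a ConstantOfShape.
# ``_demo_replace_constant`` is a hypothetical helper and is never called here.
def _demo_replace_constant() -> None:
    big = make_node(
        "Constant",
        [],
        ["big"],
        value=from_array(np.zeros((10, 100), dtype=np.float32), name="big"),
    )
    new_nodes = _replace_constant(big, threshold=128, value_constant_of_shape=0.5)
    # The large tensor is gone: only its shape [10, 100] is stored.
    assert [n.op_type for n in new_nodes] == ["Constant", "ConstantOfShape"]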


def _replace_constant_of_shape_with_range(
    onx: Union[GraphProto, FunctionProto]
) -> Union[GraphProto, FunctionProto]:
    """Replaces all *ConstantOfShape* nodes by *Range*-based nodes to avoid constant tensors.

    The function is not recursive. The recursion is handled by
    *replace_initializer_by_constant_of_shape*.
    """
    if isinstance(onx, (GraphProto, FunctionProto)):
        nodes = list(onx.node)
    else:
        raise TypeError(f"Not implemented for type {type(onx)}.")
    existing_names = set()
    for node in nodes:
        existing_names |= set(node.input)
        existing_names |= set(node.output)

    def _find_name(prefix):
        if prefix not in existing_names:
            existing_names.add(prefix)
            return prefix
        i = 2
        while True:
            name = f"{prefix}_{i}"
            if name not in existing_names:
                existing_names.add(name)
                return name
            i += 1

    cst0 = make_node("Constant", [], [_find_name("zero")], value_int=0)
    cst1 = make_node("Constant", [], [_find_name("one")], value_int=1)
    update = {}
    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        shape = node.input[0]
        # keepdims=0 makes ReduceProd return a scalar, as operator Range expects.
        n = make_node("ReduceProd", [shape], [_find_name(f"{shape}_N")], keepdims=0)
        a = make_node(
            "Range",
            [cst0.output[0], n.output[0], cst1.output[0]],
            [_find_name(f"{shape}_RANGE")],
        )
        if len(node.attribute) == 1:
            to = node.attribute[0].t.data_type
        else:
            to = TensorProto.FLOAT
        ac = make_node("Cast", [a.output[0]], [_find_name(f"{shape}_RANGEf")], to=to)
        cl = make_node("Cast", [n.output[0]], [_find_name(f"{shape}_Nf")], to=to)
        d = make_node(
            "Div", [ac.output[0], cl.output[0]], [_find_name(f"{shape}_FLAT")]
        )
        resh = make_node("Reshape", [d.output[0], shape], node.output)
        update[inode] = [n, a, ac, cl, d, resh]
    for inode, up in sorted(update.items(), reverse=True):
        nodes[inode : inode + 1] = up
    nodes.insert(0, cst0)
    nodes.insert(1, cst1)
    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
    )
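

# A hypothetical illustration (not part of the upstream module): a graph whose
# only node is a ConstantOfShape is rewritten into Range-based arithmetic
# (ReduceProd, Range, Cast, Div, Reshape) so that no constant tensor remains.
# ``_demo_replace_with_range`` is an assumed helper and is never called here.
def _demo_replace_with_range() -> None:
    cst = make_node(
        "ConstantOfShape",
        ["shape"],
        ["cst"],
        value=from_array(np.array([0.5], dtype=np.float32)),
    )
    graph = make_graph(
        [cst],
        "demo",
        [make_tensor_value_info("shape", TensorProto.INT64, [2])],
        [make_tensor_value_info("cst", TensorProto.FLOAT, [None, None])],
    )
    new_graph = _replace_constant_of_shape_with_range(graph)
    assert "ConstantOfShape" not in {n.op_type for n in new_graph.node}
    assert "Range" in {n.op_type for n in new_graph.node}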


def _replace_constant_of_shape_value(
    onx: Union[GraphProto, FunctionProto], value_constant_of_shape: float
) -> Union[GraphProto, FunctionProto]:
    """Replaces the fill value of all *ConstantOfShape* nodes."""
    if isinstance(onx, (GraphProto, FunctionProto)):
        nodes = list(onx.node)
    else:
        raise TypeError(f"Not implemented for type {type(onx)}.")
    update = {}
    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        if node.attribute:
            tensor = node.attribute[0].t
        else:
            # Operator ConstantOfShape defaults to a float32 zero
            # when the 'value' attribute is missing.
            tensor = make_tensor("value", TensorProto.FLOAT, [1], [0.0])
        new_tensor = make_tensor(
            tensor.name, tensor.data_type, [1], [value_constant_of_shape]
        )
        new_node = make_node("ConstantOfShape", node.input, node.output)
        new_node.attribute.append(make_attribute("value", new_tensor))
        update[inode] = new_node
    for inode, up in update.items():
        nodes[inode] = up
    if isinstance(onx, GraphProto):
        return make_graph(
            nodes,
            onx.name,
            onx.input,
            onx.output,
            initializer=onx.initializer,
            sparse_initializer=onx.sparse_initializer,
        )
    return make_function(
        onx.domain,
        onx.name,
        onx.input,
        onx.output,
        nodes,
        opset_imports=onx.opset_import,
    )
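

# A hypothetical illustration (not part of the upstream module): the fill
# value of an existing ConstantOfShape node is swapped for another one.
# ``_demo_replace_fill_value`` is an assumed helper and is never called here.
def _demo_replace_fill_value() -> None:
    cst = make_node(
        "ConstantOfShape",
        ["shape"],
        ["cst"],
        value=make_tensor("value", TensorProto.FLOAT, [1], [0.5]),
    )
    graph = make_graph(
        [cst],
        "demo",
        [make_tensor_value_info("shape", TensorProto.INT64, [2])],
        [make_tensor_value_info("cst", TensorProto.FLOAT, [None, None])],
    )
    new_graph = _replace_constant_of_shape_value(graph, 1e-3)
    # The node is still a ConstantOfShape but now fills with 1e-3.
    assert new_graph.node[0].attribute[0].t.float_data[0] == np.float32(1e-3)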


def replace_initializer_by_constant_of_shape(  # noqa: PLR0911
    onx: Union[FunctionProto, GraphProto, ModelProto],
    threshold: int = 128,
    ir_version: Optional[int] = None,
    use_range: bool = False,
    value_constant_of_shape: float = 0.5,
):
    """Replaces initializers or Constant nodes by *ConstantOfShape* nodes to reduce the size.

    This reduces the cost of writing a unit test about a specific graph structure.

    Args:
        onx: ModelProto
        threshold: every initializer with no more elements than this
            threshold is left untouched
        ir_version: initializers must be specified as inputs for
            `ir_version <= 3`; this must be specified if onx is a
            :class:`FunctionProto` or :class:`GraphProto`
        use_range: if True, uses operator *Range* instead of *ConstantOfShape*
            to avoid constant tensors
        value_constant_of_shape: fill value used for all *ConstantOfShape*
            nodes; a high value may produce nan or inf predictions
    Returns:
        onx, modified ModelProto

    The function is designed so that it can be reapplied to an already
    modified model, either to replace *ConstantOfShape* with *Range*
    operators or to change the fill value of every *ConstantOfShape*.
    """
    if isinstance(onx, FunctionProto):
        modified = False
        new_nodes: List[NodeProto] = []
        for node in onx.node:
            if node.op_type == "Constant":
                cst_nodes = _replace_constant(node, threshold, value_constant_of_shape)
                if len(cst_nodes) == 2:  # noqa: PLR2004
                    modified = True
                new_nodes.extend(cst_nodes)
                continue
            new_nodes.append(node)
        if modified:
            new_onx = make_function(
                onx.domain,
                onx.name,
                onx.input,
                onx.output,
                new_nodes,
                opset_imports=onx.opset_import,
            )
            if use_range:
                return _replace_constant_of_shape_with_range(new_onx)
            if value_constant_of_shape != 1:
                return _replace_constant_of_shape_value(
                    new_onx, value_constant_of_shape
                )
            return new_onx
        if use_range:
            return _replace_constant_of_shape_with_range(onx)
        if value_constant_of_shape != 1:
            return _replace_constant_of_shape_value(onx, value_constant_of_shape)
        return onx
    if isinstance(onx, ModelProto):
        new_graph = replace_initializer_by_constant_of_shape(
            onx.graph,
            ir_version=ir_version or onx.ir_version,
            threshold=threshold,
            use_range=use_range,
            value_constant_of_shape=value_constant_of_shape,
        )
        new_functions = [
            replace_initializer_by_constant_of_shape(
                f,
                threshold=threshold,
                ir_version=ir_version or onx.ir_version,
                use_range=use_range,
                value_constant_of_shape=value_constant_of_shape,
            )
            for f in onx.functions
        ]
        model = make_model(
            new_graph,
            functions=new_functions,
            producer_name=onx.producer_name,
            producer_version=onx.producer_version,
            ir_version=ir_version or onx.ir_version,
            doc_string=onx.doc_string,
            domain=onx.domain,
            model_version=onx.model_version,
        )
        if len(onx.metadata_props) > 0:  # pragma: no cover
            values = {p.key: p.value for p in onx.metadata_props}
            set_model_props(model, values)

        del model.opset_import[:]
        for oimp in onx.opset_import:
            op_set = model.opset_import.add()
            if oimp.domain == "" and oimp.version < 11 and use_range:  # noqa: PLR2004
                raise RuntimeError(
                    f"Range was introduced in opset 11 but opset is {oimp.version}."
                )
            if oimp.domain == "" and oimp.version < 9:  # noqa: PLR2004
                raise RuntimeError(
                    f"ConstantOfShape was introduced in "
                    f"opset 9 but opset is {oimp.version}."
                )
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        return model
    if not isinstance(onx, GraphProto):
        raise TypeError(f"onx should be a GraphProto at this stage not {type(onx)}.")

    n_modifications = 0
    new_nodes = []
    removed = set()
    additional_inputs = []

    new_inits: List[TensorProto] = []
    for init in onx.initializer:
        dims = tuple(init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_inits.append(init)
            continue
        n_modifications += 1
        new_name = f"{init.name}__SHAPE"
        new_inits.append(
            from_array(np.array(list(dims), dtype=np.int64), name=new_name)
        )
        dtype = tensor_dtype_to_np_dtype(init.data_type)
        node = make_node(
            "ConstantOfShape",
            [new_name],
            [init.name],
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        new_nodes.append(node)
        removed.add(init.name)
        if ir_version is not None and ir_version <= 3:  # noqa: PLR2004
            additional_inputs.append(
                make_tensor_value_info(new_name, TensorProto.INT64, [len(dims)])
            )
    new_sparse_inits: List[SparseTensorProto] = []
    for sp_init in onx.sparse_initializer:
        dims = tuple(sp_init.dims)
        size = np.prod(dims)
        if size <= threshold:
            new_sparse_inits.append(sp_init)
            continue
        raise NotImplementedError(
            f"This feature is not yet implemented for a sparse initializer "
            f"(indices.name={sp_init.indices.name!r}, "
            f"values.name={sp_init.values.name!r})."
        )

    for node in onx.node:
        if node.op_type == "Constant":
            shape_nodes = _replace_constant(node, threshold, value_constant_of_shape)
            if len(shape_nodes) == 2:  # noqa: PLR2004
                n_modifications += 1
            new_nodes.extend(shape_nodes)
            continue
        modified = False
        atts = []
        for att in node.attribute:
            if (
                att.type == AttributeProto.GRAPH
                and hasattr(att, "g")
                and att.g is not None
            ):
                g = replace_initializer_by_constant_of_shape(
                    att.g,
                    threshold=threshold,
                    ir_version=ir_version,
                    use_range=use_range,
                    value_constant_of_shape=value_constant_of_shape,
                )
                if id(g) != id(att.g):
                    modified = True
                    att = make_attribute(att.name, g)  # noqa: PLW2901
            atts.append(att)
        if modified:
            new_node = make_node(node.op_type, node.input, node.output)
            new_node.attribute.extend(atts)
            new_nodes.append(new_node)
            n_modifications += 1
        else:
            new_nodes.append(node)

    if n_modifications > 0:
        graph = make_graph(
            new_nodes,
            onx.name,
            [i for i in onx.input if i.name not in removed] + additional_inputs,
            onx.output,
            initializer=new_inits,
            sparse_initializer=new_sparse_inits,
        )
        if use_range:
            return _replace_constant_of_shape_with_range(graph)
        if value_constant_of_shape != 1:
            return _replace_constant_of_shape_value(graph, value_constant_of_shape)
        return graph
    if use_range:
        return _replace_constant_of_shape_with_range(onx)
    if value_constant_of_shape != 1:
        return _replace_constant_of_shape_value(onx, value_constant_of_shape)
    return onx
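

# A minimal usage sketch (added for illustration): a model with one large
# initializer 'W' is rewritten so that 'W' becomes a ConstantOfShape node and
# the model no longer stores the 10,000 original weights.
# ``_demo_replace_initializer`` is a hypothetical helper and is never called here.
def _demo_replace_initializer() -> None:
    from onnx.helper import make_opsetid

    x = make_tensor_value_info("X", TensorProto.FLOAT, [None, 100])
    y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, 100])
    w = from_array(np.zeros((100, 100), dtype=np.float32), name="W")
    graph = make_graph(
        [make_node("MatMul", ["X", "W"], ["Y"])], "demo", [x], [y], initializer=[w]
    )
    model = make_model(graph, opset_imports=[make_opsetid("", 18)])
    reduced = replace_initializer_by_constant_of_shape(model, threshold=128)
    assert "ConstantOfShape" in {n.op_type for n in reduced.graph.node}
    # The remaining initializer only stores the shape [100, 100].
    assert [tuple(i.dims) for i in reduced.graph.initializer] == [(2,)]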