Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import logging | |
| import queue | |
| import re | |
| from collections import defaultdict | |
| import torch | |
| from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy | |
| CLASSTYPE_KIND = 'ClassType' | |
| GETATTR_KIND = 'prim::GetAttr' | |
| CAT_KIND = 'aten::cat' | |
| LIST_CONSTRUCT_KIND = 'prim::ListConstruct' | |
| LIST_UNPACK_KIND = 'prim::ListUnpack' | |
| TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct' | |
| TUPLE_UNPACK_KIND = 'prim::TupleUnpack' | |
| _logger = logging.getLogger(__name__) | |
def build_module_graph(model, dummy_input):
    """
    Build the module-level graph of a pytorch model.

    Parameters
    ----------
    model : pytorch model
        The model to be traced.
    dummy_input : pytorch tensor
        Example input forwarded to the tracer.

    Returns
    -------
    TorchModuleGraph
        The graph object constructed from ``model`` and ``dummy_input``.
    """
    module_graph = TorchModuleGraph(model, dummy_input)
    return module_graph
def build_graph(model, dummy_input, verbose=False):
    """
    Build the protobuf graph description of a pytorch model.

    Parameters
    ----------
    model : pytorch model
        The model to be traced.
    dummy_input : pytorch tensor
        Example input forwarded to the tracer.
    verbose : bool
        Forwarded to ``TorchProtoGraph``; when true the traced graph is printed.

    Returns
    -------
    tuple
        ``(graph_def, stepstats)`` taken from the constructed ``TorchProtoGraph``.
    """
    proto_graph = TorchProtoGraph(model, dummy_input, verbose)
    graph_def = proto_graph.graph_def
    stepstats = proto_graph.stepstats
    return graph_def, stepstats
def parse_traced_name(module_name):
    """
    Strip the ``TracedModule[...]`` wrapper from a traced module name.

    Parameters
    ----------
    module_name : str
        Name reported by a traced module, e.g. ``TracedModule[Conv2d]``.

    Returns
    -------
    str
        The text between ``TracedModule[`` and the closing ``]`` when the
        name has both; otherwise ``module_name`` unchanged.
    """
    head, tail = 'TracedModule[', ']'
    wrapped = module_name.startswith(head) and module_name.endswith(tail)
    if not wrapped:
        return module_name
    # Equivalent to slicing off the wrapper at both ends.
    return module_name[len(head):len(module_name) - len(tail)]
class TorchGraph:
    """
    This class is to extract pytorch model topology graph by tracing
    """

    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor
            The dummy input for ```jit.trace```, users should put it on right device before pass in
        traced_model : torch._C.torch.jit.TopLevelTracedModule
            An already traced model, if traced_model is not None, then TorchGraph will build the graph
            based on this traced model and won't trace the model again.

        Raises
        ------
        ValueError
            If neither ``traced_model`` nor both ``model`` and ``dummy_input`` are given.
        """
        assert torch.__version__ >= '1.3.1'
        # check if the input is legal
        if traced_model is not None:
            assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
            self.trace = traced_model
            # it's ok if the graph is already unpacked
            torch._C._jit_pass_inline(self.trace.graph)
        elif model is not None and dummy_input is not None:
            self.bound_model = model
            self._trace(model, dummy_input)
        else:
            # ValueError is still caught by callers that catch ``Exception``.
            raise ValueError(
                'Please provide model & dummy_input or the traced_model as inputs')

    def _trace(self, model, dummy_input):
        """
        Trace ``model`` in evaluation mode and inline the resulting graph.

        The result is stored in ``self.trace``. The model's training flag is
        put back afterwards, even if tracing raises.

        Parameters
        ----------
        model : pytorch model
            The model to trace.
        dummy_input : pytorch tensor
            Example input passed to ```jit.trace```.
        """
        # ``torch.onnx.set_training`` is not available in recent PyTorch
        # releases, so switch the mode by hand: only touch the model when it
        # is actually in training mode, and restore it afterwards.
        was_training = model.training
        if was_training:
            model.train(False)
        try:
            self.trace = torch.jit.trace(model, dummy_input, check_trace=False)
            torch._C._jit_pass_inline(self.trace.graph)
        finally:
            if was_training:
                model.train(True)
class TorchProtoGraph(TorchGraph):
    """
    Generates model graph for pytorch models in protobuf, this implementation
    is borrowed from pytorch v1.4.0, and fixed following issues:
    https://github.com/pytorch/pytorch/issues/33691
    https://github.com/pytorch/pytorch/issues/33670
    """

    def __init__(self, model, dummy_input, verbose=False):
        """
        Trace the model and convert the trace into protobuf messages.

        Parameters
        ----------
        model : pytorch model
            The model to trace.
        dummy_input : pytorch tensor
            Example input used for tracing.
        verbose : bool
            When true, print the traced graph.
        """
        super().__init__(model, dummy_input)

        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef

        proto_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)

        device_stats = [DeviceStepStats(device="/device:CPU:0")]
        self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=device_stats))
        self.graph_def = GraphDef(
            node=proto_nodes, versions=VersionDef(producer=22))

    def parse(self, graph, trace, args=None, omit_useless_nodes=True):
        """This method parses an optimized PyTorch model graph and produces
        a list of nodes and node stats for eventual conversion to TensorBoard
        protobuf format.

        Args:
            graph (PyTorch module): The model graph to be parsed.
            trace (PyTorch JIT TracedModule): The model trace to be parsed.
            args (tuple): input tensor[s] for the model (not used by this method).
            omit_useless_nodes (boolean): Whether to skip graph inputs that have no users.

        Returns:
            The value of ``GraphPy.to_proto()`` for the collected nodes.
        """
        graph_py = GraphPy()

        # Graph inputs: drop unused ones (if requested) and class-typed ones.
        for graph_input in graph.inputs():
            if omit_useless_nodes and not graph_input.uses():
                continue
            if graph_input.type().kind() == CLASSTYPE_KIND:
                continue
            graph_py.append(NodePyIO(graph_input, 'input'))

        scope_of_attr = {}

        def short_name(cpp_node):
            # Text before the first ':' in the node's string form.
            return str(cpp_node).split(":")[0].strip()

        for cpp_node in graph.nodes():
            if cpp_node.kind() != GETATTR_KIND:
                graph_py.append(NodePyOP(cpp_node))
                continue

            attr_name = cpp_node.s('name')
            own_name = short_name(cpp_node)
            owner = cpp_node.input().node()
            if owner.kind() == GETATTR_KIND:
                # Nested attribute: extend the scope recorded for the owner.
                owner_scope = scope_of_attr[short_name(owner)]
                owner_leaf = owner_scope.split('/')[-1]
                scope_of_attr[own_name] = '{}/{}.{}'.format(
                    owner_scope, owner_leaf, attr_name)
            else:
                # Attribute read directly from the top-level "self" node.
                scope_of_attr[own_name] = '__module.{}'.format(attr_name)

            # Class-typed attributes are only needed for their scope.
            if cpp_node.output().type().kind() != CLASSTYPE_KIND:
                attr_op = NodePyOP(cpp_node)
                attr_op.scopeName = scope_of_attr[own_name]
                graph_py.append(attr_op)

        # Sink nodes for the graph outputs, numbered from 1.
        for position, graph_output in enumerate(graph.outputs(), start=1):
            sink = NodePyIO(graph_output, 'output')
            sink.debugName = "output.{}".format(position)
            sink.inputs = [graph_output.debugName()]
            graph_py.append(sink)

        base_name = parse_traced_name(trace._name)
        alias_to_name = {
            alias: '{}[{}]'.format(
                parse_traced_name(submodule._name), alias.rsplit('.', 1)[-1])
            for alias, submodule in trace.named_modules(prefix='__module')
        }

        # Rewrite each op's scope as base_name/Type[attr]/Type[attr]/...
        for op_node in graph_py.nodes_op:
            first_alias, *other_aliases = op_node.scopeName.split('/')[-1].split('.')
            qualified = first_alias
            scope = base_name
            for alias in other_aliases:
                qualified += '.' + alias
                scope += '/' + alias_to_name.get(qualified, alias)
            op_node.scopeName = scope

        graph_py.populate_namespace_from_OP_to_IO()
        return graph_py.to_proto()
class NodePyGroup(NodePy):
    """
    A graph node made of several jit traced nodes. One torch.nn.Module object
    usually produces multiple nodes in a pytorch trace graph; they are grouped
    here into a single node that stands for the module. Some functional call
    trace nodes are grouped into one node in the same way.
    """

    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
        """
        Parameters:
        -----------
        name: str
            node name, such as `conv1`, `backbone.classifier`
        unique_name: str
            A globally unique name for this node. Modules such as relu may be
            reused several times, so the scope name alone cannot identify a
            node; use unique_name to traverse the module graph.
        node_type: str
            `module` or `func`
        op_type: str
            operation type, such as `Conv2d`, `aten::view`
        node_cpps: list of torch._C.Node
            jit trace nodes which are included in this new node
        inputs: list of str
            All the inputs of this node, each element is debugName of one input
        outputs: list of str
            All the outputs of this node, each element is debugName of one output
        key_node: torch._C.Node
            The key node of this NodePyGroup.
        """
        super().__init__(name, [])
        self.name = name
        self.unique_name = unique_name
        self.type = node_type
        self.op_type = op_type
        # The core node in this NodePyGroup
        self.key_node = key_node
        self.auxiliary = None
        self.node_cpps = node_cpps
        # Filled by add_nodes with one NodePyOP per traced node.
        self.nodes = []
        self.add_nodes(node_cpps)
        # Assigned after the base constructor, which resets ``inputs``.
        self.inputs = inputs
        self.outputs = outputs

    def add_nodes(self, node_cpps):
        """Wrap every traced node in a NodePyOP named ``<scopeName>_<kind>`` and store it."""
        for cpp in node_cpps:
            wrapped = NodePyOP(cpp)
            wrapped.name = ''.join((cpp.scopeName(), '_', cpp.kind()))
            self.nodes.append(wrapped)

    def sub_node_names(self):
        """Return the names of the wrapped sub nodes, in insertion order."""
        names = []
        for sub_node in self.nodes:
            names.append(sub_node.name)
        return names

    def __repr__(self):
        fields = (
            self.name, self.type, self.op_type, self.sub_node_names(),
            self.inputs, self.outputs, self.auxiliary,
        )
        template = 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'
        return template.format(*fields)
class TorchModuleGraph(TorchGraph):
    """
    Generates model graph, each node is created from single or multiple jit trace nodes.
    """

    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Trace (or adopt) the model through ``TorchGraph`` and build the grouped graph.

        Parameters
        ----------
        model : pytorch model
            The model to trace; used together with ``dummy_input``.
        dummy_input : pytorch tensor
            Example input used for tracing.
        traced_model : torch._C.torch.jit.TopLevelTracedModule
            An already traced model, used instead of ``model``/``dummy_input`` when given.
        """
        super().__init__(model, dummy_input, traced_model)
        # Appended to the names of function nodes so that they stay distinct.
        self.global_count = 0
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
        self._extract_auxiliary_info()

    def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
                              module_type):
        """
        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
        the functions directly called in module ```forward```. Some of them are key function
        nodes (see ```_is_key_func```), the others are trivial ops. This function merges the
        non-key predecessors found in the same scope into the key node, to construct one node.

        Parameters
        ----------
        node : trace graph node
            The key function node to expand
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the key node
        input_to_node : dict
            key: input name, value: a node that uses this input
        output_to_node : dict
            key: output name, value: a node that generates this output
        module_type : str
            can be 'module' or 'func'

        Returns
        -------
        node
            the expanded node
        """
        # TODO: scope name could be empty
        node_name = '.'.join([self._get_module_name(
            node.scopeName()), node.kind(), str(self.global_count)])
        unique_name = node_name
        _logger.debug("expand non-prim node, node name: %s", node_name)
        self.global_count += 1
        op_type = node.kind()
        node_group = [node]
        inputs = []
        # Guard against merging the same predecessor twice when it is
        # reachable through more than one input edge.
        visited = {node}
        node_queue = queue.Queue()
        node_queue.put(node)
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if self._is_key_func(predecessor_node):
                        # Produced by another key node: a boundary of this group.
                        inputs.append(input_name)
                    elif predecessor_node not in visited:
                        visited.add(predecessor_node)
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                else:
                    # Produced outside of this scope (or by no node at all).
                    inputs.append(input_name)
        outputs = [output.debugName() for output in node.outputs()]
        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
                             node_group, inputs=inputs, outputs=outputs, key_node=node)
        return nodepy

    def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
                            input_to_node, output_to_node, module_type):
        """
        Merge the adjacent nodes of the module. The difference between
        _expand_module_node and _expand_key_func_node is that _expand_key_func_node
        only merges non-key predecessor nodes into the key node; in contrast,
        _expand_module_node merges all adjacent nodes (predecessors and successors)
        found in ```nodes``` into the same nodepy group.

        Parameters
        ----------
        node : trace graph node
            The node to start the expansion from
        node_name : str
            specify the node_name for NodePyGroup
        unique_name : str
            unique_name for the NodePyGroup
        op_type : str
            specify the op_type for the NodePyGroup; falls back to ```node.kind()``` when empty
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the start node
        input_to_node : dict
            key: input name, value: a node that uses this input
        output_to_node : dict
            key: output name, value: a node that generates this output
        module_type : str
            can be 'module' or 'func'

        Returns
        -------
        node
            the expanded node
        """
        _logger.debug("expand module node, node name: %s", node_name)
        self.global_count += 1
        if not op_type:
            op_type = node.kind()
        node_group = [node]
        inputs = []
        outputs = []
        node_queue = queue.Queue()
        node_queue.put(node)
        visited = {node}
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if predecessor_node not in visited:
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                        visited.add(predecessor_node)
                else:
                    # Tensor comes from outside the module: a group input.
                    inputs.append(input_name)
            for _output in curr_node.outputs():
                output_name = _output.debugName()
                if output_name in input_to_node and input_to_node[output_name] in nodes:
                    successor_node = input_to_node[output_name]
                    if successor_node not in visited:
                        node_group.append(successor_node)
                        node_queue.put(successor_node)
                        visited.add(successor_node)
                else:
                    # Tensor is not consumed inside the module: a group output.
                    outputs.append(output_name)
        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
                             node_group, inputs=inputs, outputs=outputs)
        return nodepy

    def _extract_cat_info(self, node_group, cpp_node):
        """
        Extract the detail information of the cat operation,
        such the order of the input tensor, the shape of each
        input tensor, the output shape, and the cat dimension.

        Parameters
        ----------
        node_group : NodePyGroup
        cpp_node: torch._C.Node
            It should be ```aten::cat``` node

        Returns
        -------
        dict
            Include auxiliary information for the cat operation.
            This dict object has four keys: 'cat_dim', 'out_shape',
            'in_order' and 'in_shape'. cat_dim is the dimension of
            the cat operation to concat the input tensors. out_shape
            is the shape of the output tensor of the cat operation.
            in_order is an ordered list which contains the unique names
            of the nodes that produce the input tensors; an entry is
            None when the tensor is not produced by any indexed node
            (for example an input of the whole model). in_shape is also
            an ordered list that contains the shapes of the input tensors.
        """
        # only support the cat operation
        assert cpp_node.kind() == CAT_KIND
        cat_info = {}
        # get the shape of the output tensor
        t_output = cpp_node.output()
        cat_info['out_shape'] = t_output.type().sizes()
        # inputs of aten::cat: inputs[0] is the tensor list, inputs[1] is the dimension
        cat_inputs = list(cpp_node.inputs())
        cat_info['cat_dim'] = cat_inputs[1].toIValue()
        # get the order of the input tensors
        # To get the order of the input tensors, we need
        # to be aware of the topology of the model, which
        # means we should extract the auxiliary information
        # after the build_index function.
        input_order = []
        list_construct_cpp = cat_inputs[0].node()
        input_tensors = list(list_construct_cpp.inputs())
        for _tensor in input_tensors:
            debug_name = _tensor.debugName()
            if debug_name in self.output_to_node:
                input_order.append(self.output_to_node[debug_name].unique_name)
            else:
                # No indexed node generates this tensor; keep the slot so
                # in_order stays aligned with in_shape.
                input_order.append(None)
        cat_info['in_order'] = input_order
        cat_info['in_shape'] = [t.type().sizes() for t in input_tensors]
        return cat_info

    def _extract_linear_shape_info(self, node_group):
        """
        Extract linear shape input/output tensor shape info from its aten::addmm op.

        Parameters
        ----------
        node_group : NodePyGroup
            NodePyGroup object associated with the linear module.

        Returns
        -------
        dict
            Include shape of input tensor and shape of output tensor,
            or None when the group contains no aten::addmm node.
        """
        for cpp_node in node_group.node_cpps:
            if cpp_node.kind() != 'aten::addmm':
                continue
            # https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
            # inputs of aten::addmm:
            # inputs[0] is bias
            # inputs[1] is input data
            # inputs[2] is weight
            t_input = list(cpp_node.inputs())[1]
            t_output = cpp_node.output()
            assert isinstance(t_input.type(), torch._C.TensorType)
            assert isinstance(t_output.type(), torch._C.TensorType)
            in_shape = t_input.type().sizes()
            out_shape = t_output.type().sizes()
            return {'in_shape': in_shape, 'out_shape': out_shape}
        return None

    def _extract_shape_info(self, node):
        """
        Extract the shape information of ```aten::view``` node

        Parameters
        ----------
        node : trace graph node
            It should be ```aten::view``` node

        Returns
        -------
        dict
            Include shape of input tensor and shape of output tensor
        """
        # Only the first input of the node is inspected.
        t_input = next(iter(node.inputs()), None)
        t_output = node.output()
        assert isinstance(t_input.type(), torch._C.TensorType)
        assert isinstance(t_output.type(), torch._C.TensorType)
        in_shape = t_input.type().sizes()
        out_shape = t_output.type().sizes()
        return {'in_shape': in_shape, 'out_shape': out_shape}

    def _extract_leaf_modules(self):
        """
        Extract leaf modules from the given graph. Leaf module means it does not have submodules.
        To extract leaf modules because only leaf module can be replaced. And shape inference can
        be done in leaf module level. Other shape inference is done in lower level i.e.,
        operation level.

        Returns
        -------
        list
            a list of scope name of all the leaf modules
        """
        def is_parent(name1, name2):
            """
            check if name1 is parent node of name2, for example:
            name1: aa.bb, name2: aa.bb.cc, return True
            name1: aa.b, name2: aa.bb, return False
            """
            parts1, parts2 = name1.split('.'), name2.split('.')
            if len(parts1) >= len(parts2):
                return False
            return parts2[:len(parts1)] == parts1

        # The root module has an empty name and is skipped.
        module_names = sorted([x[0]
                               for x in self.trace.named_modules() if x[0]])
        leaf_nodes = []
        # In sorted order a module is treated as a leaf when the name that
        # follows it is not one of its descendants.
        for i, name in enumerate(module_names):
            if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
                leaf_nodes.append(name)
        return leaf_nodes

    def _get_module_name(self, scope_name):
        """
        Retrieve module name from scope name.

        Parameters:
        -----------
        scope_name: str
            scope_name of a graph node, for example:
            for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
            for pytorch 1.4.0: __module.backbone/__module.backbone.conv2

        Returns:
        -------
        str
            module name, such as backbone.conv2
        """
        if torch.__version__ >= '1.4.0':
            return scope_name.split('/')[-1].replace('__module.', '')
        return '.'.join(re.findall(r'\[(.*?)\]', scope_name))

    def _build_index(self, nodes_op):
        """
        Build the lookup tables over the grouped nodes.

        Parameters
        ----------
        nodes_op : list of NodePyGroup
            The grouped nodes to index.

        Returns
        -------
        tuple
            ``(name_to_node, input_to_node, output_to_node)``: unique name to node,
            input name to the list of nodes that take it, and output name to the
            single node that generates it.
        """
        name_to_node = dict()
        input_to_node = defaultdict(list)
        output_to_node = dict()
        for node in nodes_op:
            name_to_node[node.unique_name] = node
            for _input in node.inputs:
                input_to_node[_input].append(node)
            for output in node.outputs:
                assert output not in output_to_node, \
                    "One output cannot be generated by multiple nodes"
                output_to_node[output] = node
        return name_to_node, input_to_node, output_to_node

    def _is_key_func(self, node_cpp):
        """
        Judge if a cpp node is a key function node.
        If so, we should not merge this node into the
        adjacent node.
        """
        if node_cpp.kind().startswith('aten::'):
            # the nodes that start with 'aten' are key function
            # nodes
            return True
        # We cannot merge the List/Tuple
        # Unpack func into other nodes, else it
        # may lead to a graph construction error.
        # The reason why we do not take the construct node
        # also as a key node is that `cat` operation node need
        # the last(previous) visited node to infer the mask. If
        # we take the Construct node as the important node, the
        # predecessor of the `cat` node will always be a construct
        # node, which means we cannot infer the mask for the cat
        # operation.
        return node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]

    def unpack_manually(self):
        """
        Unpack the tensor tuple or tensor list manually,
        and remove the ListUnpack/TupleUnpack node from
        the graph. Note: this function will change the
        graph structure.
        """
        if hasattr(self, 'unpacked'):
            # if already unpacked the tuple/list manually
            return
        for node in self.nodes_py.nodes_op:
            if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
                unpack_cpp = node.key_node
                last_cpp = list(unpack_cpp.inputs())[0].node()
                if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
                    # we need check if the tensor tuple or tensor list is produced
                    # by a list/tuple construct node. If so, we can unpack the tuple
                    # or list manually.
                    _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
                    _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
                    assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
                    errmsg = '%s Input number: %d if inconsistent with the output number %d' % (unpack_cpp, \
                        len(node.inputs), len(list(last_cpp.inputs())))
                    assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
                    for _debug_input, _debug_output in zip(node.inputs, node.outputs):
                        if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
                            # input_to_node[_debug_input] is a list of NodePyGroup, because
                            # one tensor can be used as input for multiple nodes at the same time.
                            # note that, in this case, the construct cpp node and unpack cpp node
                            # will be merged into the same NodePyGroup, so we remove the `node` from
                            # input_to_node[_debug_input] and directly connect this tensor to the
                            # input_to_node[_debug_output]
                            self.input_to_node[_debug_input].remove(node)
                            # add the following nodes of _output into the input_to_node[_debug_input]
                            self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
                        # rewire the consumers of _debug_output so that they read
                        # _debug_input directly, skipping the construct and unpack pair
                        if _debug_output in self.input_to_node:
                            for following_node in self.input_to_node[_debug_output]:
                                _tmp_index = following_node.inputs.index(_debug_output)
                                following_node.inputs[_tmp_index] = _debug_input
        self.unpacked = True

    def _build_graph(self):
        """
        Build graph using our defined format from jit trace.
        There are basically three steps: first, construct necessary information (data structures),
        second, extract all the modules to convert to node, Third, extract all functions to convert
        to node.

        Returns
        -------
        dict
            use name to index nodes, key: node name, value: node
        dict
            use input (its name) to index nodes,
            key: input, value: list of nodes that take this input
        dict
            use output (its name) to index nodes,
            key: output, value: node that generates this output
        """
        omit_useless_nodes = True
        graph = self.trace.graph
        # build output mapping, from output debugName to its node
        output_to_node = {x.debugName(): n for n in graph.nodes()
                          for x in n.outputs()}
        # build input mapping, from input debugName to its node
        # (when several nodes take the same input, the last one wins)
        input_to_node = {x.debugName(): n for n in graph.nodes()
                         for x in n.inputs()}
        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = defaultdict(list)
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = defaultdict(list)

        nodes_py = GraphPy()
        for node in graph.inputs():
            if omit_useless_nodes and not node.uses():
                # the input has no user (no fanout)
                continue
            if node.type().kind() != CLASSTYPE_KIND:
                nodes_py.append(NodePyIO(node, 'input'))

        self.leaf_modules = self._extract_leaf_modules()
        module_to_type = {name: parse_traced_name(
            module._name) for name, module in self.trace.named_modules()}

        # associate module name with their trace graph nodes
        for node in graph.nodes():
            module_name = self._get_module_name(node.scopeName())
            if module_name in self.leaf_modules:
                module_to_nodes[module_name].append(node)
            else:
                func_to_nodes[node.scopeName()].append(node)

        # build node group for module
        for module_name, node_cpps in module_to_nodes.items():
            use_count = 0
            merged = set()
            for node in node_cpps:
                if node in merged:
                    continue
                # modules that have same scope name may have different locations in the
                # graph. Furthermore, there are also lots of prim:: nodes that in node_cpps,
                # so we also need to call the expand_module_node.
                unique_name = module_name
                if use_count > 0:
                    unique_name = module_name + '.%d' % use_count
                node_group = self._expand_module_node(
                    node, module_name, unique_name, module_to_type[module_name],
                    node_cpps, input_to_node, output_to_node, 'module')
                nodes_py.nodes_op.append(node_group)
                use_count += 1
                merged.update(node_group.node_cpps)

        # each scope_name may have multiple funcs, we split them and create node for each of them
        # build node group for torch.nn.functional
        for nodes in func_to_nodes.values():
            # find the key function nodes of this scope
            key_func_nodes = [node for node in nodes if self._is_key_func(node)]
            # for each key function node, expand it
            for node in key_func_nodes:
                node_group = self._expand_key_func_node(
                    node, nodes, input_to_node, output_to_node, 'func')
                nodes_py.nodes_op.append(node_group)

        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            nodes_py.append(node_py)

        self.nodes_py = nodes_py
        # build index
        return self._build_index(self.nodes_py.nodes_op)

    def _extract_auxiliary_info(self):
        """
        Extract the auxiliary information for the nodegroups
        if necessary. For example, view/flatten operations may
        need the shape of the input tensor and output tensor.
        """
        for node_group in self.nodes_py.nodes_op:
            if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
                # get the input & output shape from the cpp node of that kind
                cpp_node = [x for x in node_group.node_cpps
                            if x.kind() == node_group.op_type][0]
                node_group.auxiliary = self._extract_shape_info(cpp_node)
            elif node_group.op_type == 'Linear':
                node_group.auxiliary = self._extract_linear_shape_info(node_group)
            elif node_group.op_type == CAT_KIND:
                # get the detail information for cat func
                cpp_node = [x for x in node_group.node_cpps
                            if x.kind() == node_group.op_type][0]
                node_group.auxiliary = self._extract_cat_info(
                    node_group, cpp_node)

    def find_predecessors(self, unique_name):
        """
        Find predecessor node of the given node

        Parameters
        ----------
        unique_name : str
            The unique name of the node

        Returns
        -------
        list
            a list of unique names of the nodes which are the given node's predecessors
        """
        predecessors = []
        for _input in self.name_to_node[unique_name].inputs:
            if _input not in self.output_to_node:
                _logger.debug("cannot find node with %s as its output", _input)
            else:
                node_py = self.output_to_node[_input]
                predecessors.append(node_py.unique_name)
        return predecessors

    def find_successors(self, unique_name):
        """
        Find successor nodes of the given node

        Parameters
        ----------
        unique_name : str
            The unique name of the node

        Returns
        -------
        list
            a list of unique names of the nodes which are the given node's successors
        """
        successors = []
        for output in self.name_to_node[unique_name].outputs:
            if output not in self.input_to_node:
                # may reach the output of the whole graph
                continue
            for node_py in self.input_to_node[output]:
                successors.append(node_py.unique_name)
        return successors