1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/ir/alias_analysis.h> 5 #include <torch/csrc/jit/ir/ir.h> 6 7 namespace torch { 8 namespace jit { 9 10 // Utilities for dealing with nodes that contain subgraphs. 11 // 12 // They handle the complexity of editing inputs/outputs as you merge nodes in 13 // and out of subgraphs. 14 namespace SubgraphUtils { 15 16 // Create a new subgraph node that contains only `n`. The new subgraph will have 17 // `subgraphKind` as its type. 18 // 19 // `n` is destroyed. 20 // 21 // Returns the new subgraph node. 22 TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind); 23 24 // Creates a new subgraph that only contains `n`, amd updates the new outputs 25 // of the subgraph to have the aliasing properties of the original `n` outputs 26 TORCH_API Node* createSingletonSubgraphAndUpdateAliasing( 27 Node* to_merge, 28 Symbol subgraphKind, 29 AliasDb& db); 30 31 // Merge a node into a subgraph node. If `toMerge` is also a subgraph, the 32 // subgraphs are merged. 33 // If `destroyNode` is true `toMerge` is destroyed. 34 // An optional argument 'vmap' could be used to retrieve value mappings. 35 // Values will be mapped to their new subgraph values 36 TORCH_API void mergeNodeIntoSubgraph( 37 Node* toMerge, 38 Node* subgraphNode, 39 bool destroyNode = true); 40 41 // Merges a node into a subgraph node, and updates the new outputs of the 42 // subgraph to have the aliasing properties of the corresponding `to_merge` 43 // outputs 44 TORCH_API void mergeNodeIntoSubgraphAndUpdateAliasing( 45 Node* to_merge, 46 Node* subgraphNode, 47 AliasDb& db); 48 49 TORCH_API std::vector<Node*> unmergeAliasedOutputs( 50 Node* subgraphNode, 51 AliasDb& db); 52 53 // Move nodes from a subgraph node to the outer graph. 54 // `subgraphNode` is destroyed. 55 TORCH_API void unmergeSubgraph(Node* subgraphNode); 56 57 // Move `node_to_unmerge` and its descendants after `subgraphNode` 58 // promotes any dependencies of `node_to_unmerge` to subgraphNode outputs 59 TORCH_API void unmergeNode(Node* node_to_unmerge, Node* subgraphNode); 60 61 TORCH_API bool unmergeOutputsAlisingInputs(Node* subgraphNode); 62 63 TORCH_API bool unmergeAliasedOutputs(Node* subgraphNode); 64 65 // Convenience function 66 std::shared_ptr<Graph> getSubgraph(Node* n); 67 68 TORCH_API std::string generateNameForGraph( 69 const std::shared_ptr<Graph>& graph, 70 size_t maxlen = 40, 71 const std::string& prefix = "fused"); 72 73 } // namespace SubgraphUtils 74 } // namespace jit 75 } // namespace torch 76