xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/utils/subgraph_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
6 
7 namespace torch {
8 namespace jit {
9 
10 // Utilities for dealing with nodes that contain subgraphs.
11 //
12 // They handle the complexity of editing inputs/outputs as you merge nodes in
13 // and out of subgraphs.
14 namespace SubgraphUtils {
15 
16 // Create a new subgraph node that contains only `n`. The new subgraph will have
17 // `subgraphKind` as its type.
18 //
19 // `n` is destroyed.
20 //
21 // Returns the new subgraph node.
22 TORCH_API Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
23 
24 // Creates a new subgraph that only contains `n`, amd updates the new outputs
25 // of the subgraph to have the aliasing properties of the original `n` outputs
26 TORCH_API Node* createSingletonSubgraphAndUpdateAliasing(
27     Node* to_merge,
28     Symbol subgraphKind,
29     AliasDb& db);
30 
31 // Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
32 // subgraphs are merged.
33 // If `destroyNode` is true `toMerge` is destroyed.
34 // An optional argument 'vmap' could be used to retrieve value mappings.
35 // Values will be mapped to their new subgraph values
36 TORCH_API void mergeNodeIntoSubgraph(
37     Node* toMerge,
38     Node* subgraphNode,
39     bool destroyNode = true);
40 
41 // Merges a node into a subgraph node, and updates the new outputs of the
42 // subgraph to have the aliasing properties of the corresponding `to_merge`
43 // outputs
44 TORCH_API void mergeNodeIntoSubgraphAndUpdateAliasing(
45     Node* to_merge,
46     Node* subgraphNode,
47     AliasDb& db);
48 
49 TORCH_API std::vector<Node*> unmergeAliasedOutputs(
50     Node* subgraphNode,
51     AliasDb& db);
52 
53 // Move nodes from a subgraph node to the outer graph.
54 // `subgraphNode` is destroyed.
55 TORCH_API void unmergeSubgraph(Node* subgraphNode);
56 
57 // Move `node_to_unmerge` and its descendants after `subgraphNode`
58 // promotes any dependencies of `node_to_unmerge` to subgraphNode outputs
59 TORCH_API void unmergeNode(Node* node_to_unmerge, Node* subgraphNode);
60 
61 TORCH_API bool unmergeOutputsAlisingInputs(Node* subgraphNode);
62 
63 TORCH_API bool unmergeAliasedOutputs(Node* subgraphNode);
64 
65 // Convenience function
66 std::shared_ptr<Graph> getSubgraph(Node* n);
67 
68 TORCH_API std::string generateNameForGraph(
69     const std::shared_ptr<Graph>& graph,
70     size_t maxlen = 40,
71     const std::string& prefix = "fused");
72 
73 } // namespace SubgraphUtils
74 } // namespace jit
75 } // namespace torch
76