#pragma once

#include <torch/csrc/jit/ir/ir.h>

#include <unordered_map>
#include <vector>

namespace torch::jit {

/**
 * \brief A structure describing a match of a pattern in a graph.
 *
 * The structure contains an anchor node, from which the match was found, and
 * match-maps for nodes and values. A match-map specifies the correspondence
 * between nodes in the pattern graph (match-map keys) and nodes in the actual
 * graph (match-map values). We keep such maps for both nodes and values.
 */
struct Match {
  Node* anchor;
  std::unordered_map<const Node*, Node*> nodes_map;
  std::unordered_map<const Value*, Value*> values_map;
};

/**
 * \brief Find all matches of a \p PATTERN in a \p GRAPH.
 *
 * The function returns a vector of match-descriptors (see description of
 * `struct Match`).
 *
 * Matching rules:
 *  - Pattern graph must contain a single block.
 *  - Matched subgraphs do not span across different blocks.
 *  - No uses outside the match are allowed, except for Param and Return nodes.
 *    Basically, we're matching hammocks, not arbitrary subgraphs.
 *  - The pattern graph must return only one value (i.e. it must have a single
 *    node leading to return).
 *  - Nodes that are not used in computation of the return value in the pattern
 *    graph are ignored during matching (IOW, we're essentially performing DCE
 *    on the pattern).
 *  - Pattern graph nodes cannot alias. TODO: this check is not implemented
 *    yet.
 *  - Aliasing nodes in the graph cannot constitute a match (i.e. through all
 *    found matches, no nodes in the subgraph alias with each other). TODO:
 *    this check is not implemented yet.
 *  - The matcher will not mutate either the pattern graph or the matched graph.
 *    The matched graph is taken as non-const so that Match may contain
 *    non-const pointers. This enables clients of this API to use Match to
 *    drive mutations.
 *
 * Note [Multi-output Patterns]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Subgraph matcher provides limited support for multi-output patterns. With a
 * single-output pattern, a single scan through the graph is sufficient to
 * find all the matches: given a starting node (an "anchor"), we can
 * deterministically check whether a pattern matches a subgraph corresponding to
 * this anchor node. For a general case of multi-output patterns, we would have
 * N anchors, which would result in M^N comparisons (M is the size of the
 * graph). Clearly this is computationally prohibitive.
 *
 * To overcome this, we impose some constraints on the multi-output patterns
 * that we accept. We require that checking whether the pattern matches a
 * subgraph would still be fully determined by a single node in the graph. To
 * achieve this, we designate the first output in the pattern as the "main"
 * output and assume that we can traverse up from this node to match the
 * entire pattern.
 *
 * Corollary 1: the order of outputs in the pattern matters!
 * Corollary 2: patterns cannot contain any nodes not participating in the main
 * output computation.
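 *
 * Example: a minimal usage sketch, not a definitive recipe. It assumes that
 * parseIR() from torch/csrc/jit/ir/irparser.h is used to build both graphs;
 * the operator name `a::aaa` is a made-up placeholder.
 * \code
 *   #include <torch/csrc/jit/ir/irparser.h>
 *   #include <torch/csrc/jit/ir/subgraph_matcher.h>
 *
 *   using namespace torch::jit;
 *
 *   Graph graph, pattern;
 *   parseIR(R"IR(
 * graph(%0):
 *   %a = a::aaa(%0)
 *   return (%a))IR", &graph);
 *   parseIR(R"IR(
 * graph(%0):
 *   %x = a::aaa(%0)
 *   return (%x))IR", &pattern);
 *
 *   std::vector<Match> matches = findPatternMatches(pattern, graph);
 *   for (const Match& m : matches) {
 *     Node* anchor = m.anchor; // graph node from which this match was found
 *     // m.nodes_map and m.values_map map pattern nodes/values (keys) to the
 *     // corresponding graph nodes/values (values).
 *   }
 * \endcode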
 */
std::vector<Match> TORCH_API
findPatternMatches(const Graph& pattern, Graph& graph);

} // namespace torch::jit