#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit::tensorexpr {

// Optimize aten::cat ops in the given subgraph.
//
// Moving users of cat to its inputs.
//    Cat ops get lowered into multiple loops, one per input. When the result
//    of cat is used by another op, cat cannot be inlined into that op, so an
//    intermediate buffer has to be allocated for the result of cat.
//
//    For example, consider the following graph:
//       graph(%x : Float(10, strides=[1], device=cpu),
//             %y : Float(20, strides=[1], device=cpu)):
//         %dim : int = prim::Constant[value=0]()
//         %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
//         %cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
//         %5 : Float(30, strides=[1], device=cpu) = aten::log(%cat)
//         return (%5)
//
//     This will get lowered into:
//         Allocate(aten_cat);
//         for (...)
//           aten_cat[...] = x[...]
//         for (...)
//           aten_cat[...] = y[...]
//         for (...)
//           aten_log[...] = log(aten_cat[...])
//         Free(aten_cat);
//     Note that aten_cat is not inlined into aten_log, which results in an
//     extra intermediate buffer allocation.
//
//     Optimization:
//        We move the ops that use the result of `cat` into its inputs whenever
//     possible.
//
//     The graph above will be transformed to:
//        graph(%x : Float(10, strides=[1], device=cpu),
//              %y : Float(20, strides=[1], device=cpu)):
//          %3 : int = prim::Constant[value=0]()
//          %7 : Float(10, strides=[1], device=cpu) = aten::log(%x)
//          %8 : Float(20, strides=[1], device=cpu) = aten::log(%y)
//          %9 : Tensor[] = prim::ListConstruct(%7, %8)
//          %10 : Float(30, strides=[1], device=cpu) = aten::cat(%9, %3)
//          return (%10)
//
//     This will get lowered into:
//         for (...)
//           aten_cat[...] = log(x[...])
//         for (...)
//           aten_cat[...] = log(y[...])
//     Here aten_cat is the output buffer, so no intermediate allocation is
//     needed.
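//
//     A minimal usage sketch; `ir_str` (the IR text above) is hypothetical,
//     and we assume the return value reports whether the graph was rewritten
//     (parseIR comes from torch/csrc/jit/ir/irparser.h):
//         auto graph = std::make_shared<Graph>();
//         parseIR(ir_str, graph.get());
//         bool changed = OptimizeCat(graph);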

bool OptimizeCat(const std::shared_ptr<Graph>& graph);

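// Annotate the graph's input types using the given example inputs: for each
// graph input whose example entry is not std::nullopt, set that input's type
// to the complete tensor type of the example.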
TORCH_API void annotateInputShapes(
    const std::shared_ptr<Graph>& graph,
    const std::vector<std::optional<at::Tensor>>& example_inputs);
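// Remove the unused `self` argument (the first input of a graph produced from
// a scripted method) and return the updated graph.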
TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
    const std::shared_ptr<Graph>& graph);
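// Remove the graph output at position \p IDX and return the updated graph.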
TORCH_API std::shared_ptr<Graph> removeGraphOutput(
    const std::shared_ptr<Graph>& graph,
    size_t idx);
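// Replace a list output (e.g. one built by prim::ListConstruct) with a tuple
// output and return the updated graph.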
TORCH_API std::shared_ptr<Graph> replaceListOutputWithTuple(
    const std::shared_ptr<Graph>& graph);

// Perform \p ITERS rounds of "trimming" for the given \p GRAPH.
//
// Trimming means that we try to remove a small portion of the graph while
// keeping it valid. This is useful for debugging when we try to find a minimal
// example reproducing the issue at hand. When ITERS is 0, the graph remains
// unchanged; when ITERS is large, the graph usually becomes empty.
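//
// A usage sketch, assuming an existing `graph` (the iteration count 10 is an
// arbitrary choice):
//     auto reduced = trimGraph(graph, 10);
//     std::cout << *reduced; // inspect the smaller reproducer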
TORCH_API std::shared_ptr<Graph> trimGraph(
    const std::shared_ptr<Graph>& graph,
    int64_t iters);

// Scan all values in the given graph and replace each dimension whose size Xi
// is present in \p SIZES with a symbolic shape Yi. Return a vector of symbol
// values [Y0, Y1, .., Yn].
//
// For example:
// Input:
// graph(%x : Float(10, 20, 30, 40)):
//   %y : Float(10, 20, 30, 40) = aten::relu(%x)
//   return %y
//
// If we run makeShapesSymbolic(graph, {20, 40}), then we'll get:
//
// graph(%x : Float(10, SS(-3), 30, SS(-5))):
//   %y : Float(10, SS(-3), 30, SS(-5)) = aten::relu(%x)
//   return %y
//
// and get {-3, -5} as the return value.
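//
// The corresponding call, matching the example above:
//     std::vector<int64_t> syms = makeShapesSymbolic(graph, {20, 40});
//     // syms == {-3, -5}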
TORCH_API std::vector<int64_t> makeShapesSymbolic(
    std::shared_ptr<Graph>& graph,
    const std::vector<int64_t>& sizes);

// Inspect the graph and report whether it can be converted to TE IR.
// TODO: add error reporting for graphs that can't be converted.
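//
// A hypothetical guard before attempting TE lowering:
//     if (!isGraphCompilable(graph)) {
//       return; // fall back to another executor
//     }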
TORCH_API bool isGraphCompilable(const std::shared_ptr<Graph>& graph);

// Examine the graph and (hackily) fill in missing tensor type info, such as
// scalar type, device, and strides. Ideally, this should be done by proper
// dtype/device/shape propagation passes, but until those are ready we can use
// this workaround pass, which is not always correct.
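//
// A hedged call sketch; the variables and the ordering relative to other
// passes are assumptions, not the library's prescribed pipeline:
//     annotateInputShapes(graph, example_inputs);
//     fixupMissingShapeInfo(graph); // best-effort fill-in of missing type info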
TORCH_API void fixupMissingShapeInfo(const std::shared_ptr<Graph>& graph);

} // namespace torch::jit::tensorexpr