#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit::tensorexpr {

// Optimize aten::cat ops in the given subgraph.
//
// This pass moves users of cat to its inputs.
//
// Cat ops get lowered into multiple loops, one per input. When the result
// of cat is used by some other op, inlining of cat does not happen. This in
// turn results in intermediate buffers being created for the result of cat,
// since it is not inlined.
//
// For example, consider the following graph:
//
//   graph(%x : Float(10, strides=[1], device=cpu),
//         %y : Float(20, strides=[1], device=cpu)):
//     %dim : int = prim::Constant[value=0]()
//     %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
//     %cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
//     %5 : Float(30, strides=[1], device=cpu) = aten::log(%cat)
//     return (%5)
//
// This will get lowered into:
//
//   Allocate(aten_cat);
//   for (...)
//     aten_cat[...] = x[...]
//   for (...)
//     aten_cat[...] = y[...]
//   for (...)
//     aten_log[...] = log(aten_cat[...])
//   Free(aten_cat);
//
// Note that aten_cat is not inlined into aten_log, and it results in an
// intermediate buffer allocation as well.
//
// Optimization:
// We move the ops that use the result of `cat` into its inputs whenever
// possible.
//
// The graph above will be transformed to:
//
//   graph(%x : Float(10, strides=[1], device=cpu),
//         %y : Float(20, strides=[1], device=cpu)):
//     %3 : int = prim::Constant[value=0]()
//     %7 : Float(10, strides=[1], device=cpu) = aten::log(%x)
//     %8 : Float(20, strides=[1], device=cpu) = aten::log(%y)
//     %9 : Tensor[] = prim::ListConstruct(%7, %8)
//     %10 : Float(30, strides=[1], device=cpu) = aten::cat(%9, %3)
//     return (%10)
//
// This will get lowered into:
//
//   for (...)
//     aten_cat[...] = log(x[...])
//   for (...)
//     aten_cat[...] = log(y[...])
//
// aten_cat is the output buffer here.
bool OptimizeCat(const std::shared_ptr<Graph>& graph);

TORCH_API void annotateInputShapes(
    const std::shared_ptr<Graph>& graph,
    const std::vector<std::optional<at::Tensor>>& example_inputs);
TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
    const std::shared_ptr<Graph>& graph);
TORCH_API std::shared_ptr<Graph> removeGraphOutput(
    const std::shared_ptr<Graph>& graph,
    size_t idx);
TORCH_API std::shared_ptr<Graph> replaceListOutputWithTuple(
    const std::shared_ptr<Graph>& graph);

// Perform \p ITERS rounds of "trimming" for the given \p GRAPH.
//
// Trimming means that we try to remove a small portion of the graph while
// keeping it valid. This is useful for debugging when we try to find a minimal
// example reproducing the issue at hand. When ITERS is 0, the graph remains
// unchanged; when ITERS is large, the graph usually becomes empty.
TORCH_API std::shared_ptr<Graph> trimGraph(
    const std::shared_ptr<Graph>& graph,
    int64_t iters);
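
// A minimal usage sketch for the two passes above (hypothetical; assumes the
// graph has already been obtained elsewhere, e.g. parsed from TorchScript IR,
// and that shape/dtype info is present on its values):
//
//   std::shared_ptr<Graph> g = ...;  // graph containing aten::cat and users
//   bool changed = OptimizeCat(g);   // move users of cat into its inputs
//   // When minimizing a failing case, shrink the graph step by step:
//   g = trimGraph(g, /*iters=*/1);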

// Scan all values in the given graph and replace each dimension whose size Xi
// is present in \p SIZES with a symbolic shape Yi. Return a vector of the
// symbol values [Y0, Y1, ..., Yn].
//
// For example:
// Input:
//   graph(%x : Float(10, 20, 30, 40)):
//     %y : Float(10, 20, 30, 40) = aten::relu(%x)
//     return %y
//
// If we run makeShapesSymbolic(graph, {20, 40}), then we'll get:
//
//   graph(%x : Float(10, SS(-3), 30, SS(-5))):
//     %y : Float(10, SS(-3), 30, SS(-5)) = aten::relu(%x)
//     return %y
//
// and get {-3, -5} as the return value.
TORCH_API std::vector<int64_t> makeShapesSymbolic(
    std::shared_ptr<Graph>& graph,
    const std::vector<int64_t>& sizes);

// Inspect the graph and report whether it can be converted to TE IR.
// TODO: add error reporting for graphs that can't be converted.
TORCH_API bool isGraphCompilable(const std::shared_ptr<Graph>& graph);

// Examine the graph and (hackily) fill in missing tensor type info, such as
// scalar type, device, and strides. Ideally, this should be done by proper
// dtype/device/shape propagation passes, but until they are ready we can use
// this not-always-correct workaround pass.
TORCH_API void fixupMissingShapeInfo(const std::shared_ptr<Graph>& graph);

} // namespace torch::jit::tensorexpr
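
// Example debugging workflow combining the shape utilities declared above
// (a hypothetical sketch; the example inputs and the sizes {20, 40} are
// illustrative only):
//
//   std::vector<std::optional<at::Tensor>> example_inputs = {...};
//   annotateInputShapes(g, example_inputs);       // fill in input tensor types
//   auto syms = makeShapesSymbolic(g, {20, 40});  // make dims 20 and 40 symbolic
//   if (!isGraphCompilable(g)) {
//     fixupMissingShapeInfo(g);                   // best-effort type fixup
//   }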