#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>

namespace torch::jit {

// Takes in a TensorExprGraph of static shapes and generalizes the input
// shapes to symbolic dimensions. Dimensions of value 1 are preserved;
// otherwise, dimensions with the same value are bucketed into the same
// symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// From there, it runs symbolic shape inference on the graph and creates a
// versioning if in the graph, with prim::TensorExprDynamicGuard checking
// whether the runtime inputs match the generalized symbolic shapes that are
// inputs to the TE kernel. The computation of all symbolic dimensions is
// inlined into the if block along with the TE kernel. All symbolic dimension
// Value*s are appended to the end of the TE kernel Graph/Node inputs, and the
// Node is augmented with an integer list attr `symbolic_shape_inputs` that
// gives the mapping from Value* -> symbolic shape int64_t value. For lengthier
// IR examples and a walkthrough, look at ShapeAnalysisTest.DynamicShapesFusion
// in `test_shape_analysis`.
// Returns true on success and false on failure; it can fail if shape
// propagation fails to propagate the number of dims, or if complete shapes on
// the inputs are not set.
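//
// A rough sketch of the shape of the resulting IR (illustrative, not taken
// from an actual run; the prim::TensorExprDynamicGuard and
// `symbolic_shape_inputs` names come from this pass, while the group and
// fallback op spellings below are assumptions):
//
//   %ok : bool = prim::TensorExprDynamicGuard[types=[...]](%x, %y)
//   %out : Tensor = prim::If(%ok)
//     block0():
//       ... compute each symbolic dimension from the input sizes ...
//       %r = prim::TensorExprGroup_0[symbolic_shape_inputs=[-2, -1]](
//           %x, %y, %SS_2, %SS_1)
//       -> (%r)
//     block1():
//       %r = prim::FallbackGraph_0(%x, %y)
//       -> (%r)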

TORCH_API bool GenerateGuard(
    Node* tensorexpr_graph_node,
    bool add_composed_op = false);
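
// Usage sketch (hedged): `graph` is assumed to already contain a
// prim::TensorExprGroup node produced by the TensorExpr fuser; the loop below
// is illustrative rather than an existing pass.
//
//   for (Node* n : graph->nodes()) {
//     if (n->kind() == prim::TensorExprGroup) {
//       bool ok = GenerateGuard(n, /*add_composed_op=*/false);
//       if (!ok) {
//         // Generalization failed (e.g. the number of dims could not be
//         // propagated, or complete input shapes were not set); the caller
//         // must handle this node some other way.
//       }
//     }
//   }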

TORCH_API void runTensorExprDynamicGroup(const Code& code, Stack& stack);
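
// Usage sketch (hedged): the Code/Stack setup below is illustrative;
// `guarded_graph` and the function name are assumptions, not a fixed API.
//
//   Code code(guarded_graph, /*function_name=*/"tensorexpr_dynamic_group");
//   Stack stack{at::randn({8, 4}), at::randn({8, 4})};
//   runTensorExprDynamicGroup(code, stack);
//   // The outputs are left on the stack.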

enum class StrideInput {
  // Tensors natively store whether they are contiguous or not as a property,
  // which makes it faster to query `is_contiguous` or
  // `is_contiguous(memory_format=channels_last)` than to loop over the
  // sizes/strides yourself.
  // For tensors with these properties, we only store one value:
  TENSOR_CONT,
  TENSOR_CONT_CHANNELS_LAST,
  // For the remaining cases, we describe the strides with one enum value per
  // dimension (see the worked example after the enum):
  S_ONE, // STRIDE_ONE: packed, stride == 1
  S_CONT, // STRIDE_CONTIGUOUS: stride[i] == stride[i + 1] * sizes[i + 1]
  S_TRAN_CONT, // STRIDE_TRANSPOSED_CONTIGUOUS: stride[i] == stride[i - 1] * sizes[i - 1]
  S_AS_ARG, // STRIDE_AS_ARG: stride passed in as a runtime value
};
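
// Worked example (illustrative): a contiguous tensor of sizes [2, 3, 4] has
// strides [12, 4, 1]. Described per dimension, the innermost dim is S_ONE
// (stride 1) and each outer dim is S_CONT, since stride[i] ==
// stride[i + 1] * sizes[i + 1] (4 == 1 * 4, 12 == 4 * 3). Because the tensor
// is contiguous overall, it would instead be summarized with the single
// TENSOR_CONT value.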

TORCH_API std::string toString(StrideInput si);
TORCH_API StrideInput strideInputFromString(const std::string& si);

} // namespace torch::jit