xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/symbolic_shape_analysis.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <unordered_map>
6 #include <utility>
7 #include <variant>
8 
9 namespace torch::jit {
10 
11 // CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE
12 
13 TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph);
14 
15 // CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE
16 // From [beg, end) attempt to propagate shapes and
17 // build up a graph that will compute all remaining symbolic
18 // shapes in [beg, end) that can be executed before beg
19 
20 struct ShapeComputeGraphMapping {
ShapeComputeGraphMappingShapeComputeGraphMapping21   ShapeComputeGraphMapping(
22       std::shared_ptr<Graph> partial_eval_shape_graph,
23       std::unordered_map<Value*, Value*>
24           enclosing_graph_value_to_shape_graph_input,
25       std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim)
26       : partial_eval_shape_graph(std::move(partial_eval_shape_graph)),
27         enclosing_graph_value_to_shape_graph_input_(
28             std::move(enclosing_graph_value_to_shape_graph_input)),
29         graph_output_to_symbolic_shape_dim_(
30             std::move(graph_output_to_symbolic_shape_dim)){};
31 
32   std::shared_ptr<Graph> partial_eval_shape_graph;
33   std::unordered_map<Value*, Value*>
34       enclosing_graph_value_to_shape_graph_input_;
35   std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim_;
36 };
37 
38 TORCH_API std::optional<ShapeComputeGraphMapping>
39 PropagateShapesAndBuildLargeShapeComputeGraph(
40     std::shared_ptr<Graph>& graph,
41     Node* beg,
42     Node* end);
43 
44 // don't insert complete tensor shapes in shape compute graphs and instead
45 // rely on our partial evaluation pipeline to propagate information.
46 // this is a good proxy for our ability to propagate non-complete shape
47 // information.
48 TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value);
49 TORCH_API bool symbolicShapeAnalysisTestModeEnabled();
50 
51 using SSAInput = std::variant<IValue, c10::SymbolicShape>;
52 TORCH_API std::optional<std::vector<c10::SymbolicShape>>
53 calculateSymbolicShapesOnOp(
54     const FunctionSchema* schema,
55     const std::vector<SSAInput>& inputs);
56 } // namespace torch::jit
57