1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 #include <unordered_map> 6 #include <utility> 7 #include <variant> 8 9 namespace torch::jit { 10 11 // CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE 12 13 TORCH_API void PropagateShapesOnGraph(std::shared_ptr<Graph>& graph); 14 15 // CAUTION NOT TO BE USED, STILL A WIP, NOT STABLE 16 // From [beg, end) attempt to propagate shapes and 17 // build up a graph that will compute all remaining symbolic 18 // shapes in [beg, end) that can be executed before beg 19 20 struct ShapeComputeGraphMapping { ShapeComputeGraphMappingShapeComputeGraphMapping21 ShapeComputeGraphMapping( 22 std::shared_ptr<Graph> partial_eval_shape_graph, 23 std::unordered_map<Value*, Value*> 24 enclosing_graph_value_to_shape_graph_input, 25 std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim) 26 : partial_eval_shape_graph(std::move(partial_eval_shape_graph)), 27 enclosing_graph_value_to_shape_graph_input_( 28 std::move(enclosing_graph_value_to_shape_graph_input)), 29 graph_output_to_symbolic_shape_dim_( 30 std::move(graph_output_to_symbolic_shape_dim)){}; 31 32 std::shared_ptr<Graph> partial_eval_shape_graph; 33 std::unordered_map<Value*, Value*> 34 enclosing_graph_value_to_shape_graph_input_; 35 std::unordered_map<Value*, int64_t> graph_output_to_symbolic_shape_dim_; 36 }; 37 38 TORCH_API std::optional<ShapeComputeGraphMapping> 39 PropagateShapesAndBuildLargeShapeComputeGraph( 40 std::shared_ptr<Graph>& graph, 41 Node* beg, 42 Node* end); 43 44 // don't insert complete tensor shapes in shape compute graphs and instead 45 // rely on our partial evaluation pipeline to propagate information. 46 // this is a good proxy for our ability to propagate non-complete shape 47 // information. 48 TORCH_API bool setSymbolicShapeAnalysisTestMode(bool value); 49 TORCH_API bool symbolicShapeAnalysisTestModeEnabled(); 50 51 using SSAInput = std::variant<IValue, c10::SymbolicShape>; 52 TORCH_API std::optional<std::vector<c10::SymbolicShape>> 53 calculateSymbolicShapesOnOp( 54 const FunctionSchema* schema, 55 const std::vector<SSAInput>& inputs); 56 } // namespace torch::jit 57