1 #pragma once 2 // This file is temporary until native_functions.yaml and derivatives.yaml are 3 // merged. Ideally this should all go into native_functions.yaml 4 5 #include <torch/csrc/Export.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 8 namespace torch::jit { 9 10 /* 11 ADDING A NEW SHAPE GRAPH: 12 - For one node schema, there is one corresponding registered shape compute 13 graph. The schema of the graph should be the same except for Tensor arguments. 14 For every Tensor input in operator schema, there should be a List[int] 15 corresponding to that Tensor's shape. For example: "aten::linear(Tensor input, 16 Tensor weight, Tensor? bias=None) -> Tensor" ==> def linear(input: List[int], 17 weight: List[int], bias: Optional[List[int]]) 18 19 Additionally, arguments which are unused at the end of the schema may be left 20 off. This allows sharing a single graph for multiple function schemas, such as 21 unary operators with different trailing arguments that do not affect the output 22 shape. 23 24 The shape graph should return a new, unaliased List[int] (or tuple of lists for 25 multiple returns) and should not modify any input lists. This allows the shape 26 graphs to be composed and executed. 27 28 The shape analysis (particularly for non-complete, or symbolic shapes) works by 29 partially evaluating the JIT IR. It may be possible for a Graph to be registered 30 that we cannot currently partially evaluate. If this happens, please file an 31 issue. There are lints registered to avoid particular known patterns (continue 32 or break or early return in a loop). Those may be improved in the future, please 33 file an issue if necessary. 34 35 To debug (and write initially) the recommended flow is to define these functions 36 in python and iterate there. Functions should be added to 37 torch/jit/_shape_functions. 38 39 To test operators, the preferred flow is through OpInfos, with 40 `assert_jit_shape_analysis=True`. If this is not feasible, you can look at tests 41 in `test_symbolic_shape_analysis.py` such as `test_adaptive_avg_pool2d`. 42 43 Operators which take in a list of tensors, such as concat, are not yet 44 supported. Concat has been special cased and could be generalized as needed. 45 Please file an issue. 46 */ 47 48 struct BoundedShapeGraphs { 49 std::shared_ptr<Graph> lower_bound; 50 std::shared_ptr<Graph> upper_bound; 51 }; 52 53 TORCH_API void RegisterShapeComputeGraphForSchema( 54 const FunctionSchema& schema, 55 const std::shared_ptr<Graph>& g); 56 57 TORCH_API std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema( 58 const FunctionSchema& schema); 59 60 TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema( 61 const FunctionSchema& schema); 62 63 TORCH_API std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas(); 64 65 TORCH_API void LintShapeComputeGraph( 66 const FunctionSchema* schema, 67 const std::shared_ptr<Graph>& graph); 68 69 } // namespace torch::jit 70