xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/symbolic_shape_registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 // This file is temporary until native_functions.yaml and derivatives.yaml are
3 // merged. Ideally this should all go into native_functions.yaml
4 
5 #include <torch/csrc/Export.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 
8 namespace torch::jit {
9 
10 /*
11 ADDING A NEW SHAPE GRAPH:
12 - For one node schema, there is one corresponding registered shape compute
13 graph. The schema of the graph should be the same except for Tensor arguments.
14 For every Tensor input in operator schema, there should be a List[int]
15 corresponding to that Tensor's shape. For example: "aten::linear(Tensor input,
16 Tensor weight, Tensor? bias=None) -> Tensor" ==> def linear(input: List[int],
17 weight: List[int], bias: Optional[List[int]])
18 
19 Additionally, arguments which are unused at the end of the schema may be left
20 off. This allows sharing a single graph for multiple function schemas, such as
21 unary operators with different trailing arguments that do not affect the output
22 shape.
23 
24 The shape graph should return a new, unaliased List[int] (or tuple of lists for
25 multiple returns) and should not modify any input lists. This allows the shape
26 graphs to be composed and executed.
27 
28 The shape analysis (particularly for non-complete, or symbolic shapes) works by
29 partially evaluating the JIT IR. It may be possible for a Graph to be registered
30 that we cannot currently partially evaluate. If this happens, please file an
31 issue. There are lints registered to avoid particular known patterns (continue
32 or break or early return in a loop). Those may be improved in the future, please
33 file an issue if necessary.
34 
35 To debug (and write initially) the recommended flow is to define these functions
36 in python and iterate there. Functions should be added to
37 torch/jit/_shape_functions.
38 
39 To test operators, the preferred flow is through OpInfos, with
40 `assert_jit_shape_analysis=True`. If this is not feasible, you can look at tests
41 in `test_symbolic_shape_analysis.py` such as `test_adaptive_avg_pool2d`.
42 
43 Operators which take in a list of tensors, such as concat, are not yet
44 supported. Concat has been special cased and could be generalized as needed.
45 Please file an issue.
46 */
47 
// A pair of shape-compute graphs used when an operator's output shape cannot
// be computed exactly and is instead bounded from both sides.
// NOTE(review): the exact lower/upper-bound semantics (e.g. elementwise bounds
// on each output dimension) are inferred from the member names and from
// boundedGraphsForSchema below — confirm against the registry implementation.
struct BoundedShapeGraphs {
  std::shared_ptr<Graph> lower_bound; // graph computing a lower bound on the output shape(s)
  std::shared_ptr<Graph> upper_bound; // graph computing an upper bound on the output shape(s)
};
52 
// Registers `g` as the shape compute graph for `schema`. The graph must follow
// the conventions described in the comment block at the top of this file: one
// List[int] per Tensor argument in the operator schema, optionally omitting
// unused trailing arguments, and returning fresh, unaliased List[int] values
// (or a tuple of lists for multiple returns) without mutating its inputs.
TORCH_API void RegisterShapeComputeGraphForSchema(
    const FunctionSchema& schema,
    const std::shared_ptr<Graph>& g);
56 
// Looks up the shape compute graph registered for `schema`.
// Returns std::nullopt when no graph has been registered for it.
TORCH_API std::optional<std::shared_ptr<Graph>> shapeComputeGraphForSchema(
    const FunctionSchema& schema);
59 
// Looks up the lower/upper-bound shape graph pair registered for `schema`,
// for operators whose output shape can only be bounded rather than computed
// exactly (see BoundedShapeGraphs above).
// Returns std::nullopt when no bounded graphs are registered for it.
TORCH_API std::optional<BoundedShapeGraphs> boundedGraphsForSchema(
    const FunctionSchema& schema);
62 
// Returns non-owning pointers to every operator schema that currently has a
// shape compute graph registered. NOTE(review): pointer lifetime presumably
// matches the process-lifetime registry — confirm in the implementation
// before caching these across registry mutations.
TORCH_API std::vector<const FunctionSchema*> RegisteredShapeComputeSchemas();
64 
// Validates that `graph` is a well-formed shape compute graph for `schema`.
// The header comment above mentions lints for known-unsupported patterns
// (continue, break, or early return inside a loop) that the partial
// evaluator cannot handle. NOTE(review): failure behavior (throw vs.
// TORCH_CHECK) is defined in the implementation — confirm there.
TORCH_API void LintShapeComputeGraph(
    const FunctionSchema* schema,
    const std::shared_ptr<Graph>& graph);
68 
69 } // namespace torch::jit
70