# torch.onnx

Torch->ONNX converter / exporter.

- User-facing docs: https://pytorch.org/docs/main/onnx.html
- Developer docs: https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter

> Read the following if you are contributing to `torch.onnx`.

## Symbolic function opsets

Opset 9 is the base version. It was selected as the base version because:

1. It is the first opset version supported by PyTorch export.
2. Opset 9 is more robust than previous opset versions. Opset versions such as 7 and 8 have
   limitations: certain basic operators cannot be expressed in ONNX. Rather than building on
   these limited versions, we chose to handle them separately as special cases.

Backward support for opset versions earlier than opset 7 is not on our roadmap.

By default, opset versions other than 9 inherit the symbolic functions defined in
`symbolic_opset9.py`.

To extend support for operators updated in other opset versions on top of opset 9,
add the updated symbolic functions to the respective `symbolic_opset{version}.py` file.
See `topk` in `symbolic_opset10.py` and `upsample_nearest2d` in `symbolic_opset8.py` for examples.
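
The inheritance described above can be modeled as a per-operator lookup over the opset versions that have a registered symbolic. The following is a simplified sketch, not the exporter's code. `REGISTRY`, `dispatch_opset_version`, and `get_symbolic` are hypothetical names, and the real lookup in the internal registration module may differ in detail. The sketch assumes two rules. For a target of opset 9 or newer, the newest registration that is not newer than the target wins. For the legacy opsets 7 and 8, the search goes upward toward opset 9.

```python
from typing import Callable, Dict, Optional

BASE_OPSET = 9  # opset 9 is the base version


def dispatch_opset_version(target: int, registered: Dict[int, Callable[[], str]]) -> Optional[int]:
    """Return the registered opset version that should serve `target`, or None."""
    versions = sorted(registered)
    if target >= BASE_OPSET:
        # Newest registration that is not newer than the target opset.
        older = [v for v in versions if v <= target]
        return max(older) if older else None
    # Assumed legacy behavior for opsets 7/8: search upward toward opset 9.
    newer = [v for v in versions if target <= v <= BASE_OPSET]
    return min(newer) if newer else None


# Hypothetical registry: operator name -> {opset version: symbolic function}.
REGISTRY: Dict[str, Dict[int, Callable[[], str]]] = {
    "topk": {9: lambda: "topk@9", 10: lambda: "topk@10"},
    "upsample_nearest2d": {8: lambda: "upsample@8", 9: lambda: "upsample@9"},
}


def get_symbolic(name: str, opset: int) -> Callable[[], str]:
    version = dispatch_opset_version(opset, REGISTRY[name])
    if version is None:
        raise KeyError(f"{name} has no symbolic for opset {opset}")
    return REGISTRY[name][version]


print(get_symbolic("topk", 13)())                # topk@10 (inherited from opset 10)
print(get_symbolic("topk", 9)())                 # topk@9
print(get_symbolic("upsample_nearest2d", 8)())   # upsample@8
print(get_symbolic("upsample_nearest2d", 11)())  # upsample@9 (inherited from opset 9)
```

With this model, adding a file for a new opset only requires registering the operators that changed; every other operator falls back to an older registration.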

## Editing Symbolic Files

- Use the internal `registration.onnx_symbolic` decorator to register a new symbolic function. Search for `def reshape(g, self, shape):` to see an example.
- Parameter names must *exactly* match the names in
  `aten/src/ATen/native/native_functions.yaml`, because
  dispatch is done with keyword arguments.
- Looking for inplace ops? They're detected by
  `_jit_pass_onnx_remove_inplace_ops_for_onnx` and
  transparently dispatched to their non-inplace versions in
  `run_symbolic_function`. See Note [Export inplace](#export-inplace).
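
A toy dispatcher shows why the names must match exactly. This is a hypothetical sketch: `run_symbolic`, `my_op`, and its `scale` argument are invented for illustration. It assumes that node inputs are passed positionally and the remaining schema arguments by keyword, which mirrors the `add(*[self, other], **{"alpha": alpha})` dispatch shown under "Pointwise by scalar".

```python
from typing import Any, Callable, Dict, List, Tuple


def run_symbolic(fn: Callable[..., Any], g: Any, inputs: List[Any], attrs: Dict[str, Any]) -> Any:
    """Hypothetical dispatcher: inputs positionally, schema arguments by keyword."""
    return fn(g, *inputs, **attrs)


# Hypothetical schema: my_op(Tensor self, float scale) -> Tensor
def my_op(g: Any, self: str, scale: float) -> Tuple[str, str, float]:
    # `scale` matches the schema argument name, so the keyword binds.
    return ("MyOp", self, scale)


def my_op_misnamed(g: Any, self: str, factor: float) -> Tuple[str, str, float]:
    # `factor` does not match `scale`: Python raises TypeError at dispatch time.
    return ("MyOp", self, factor)


print(run_symbolic(my_op, None, ["x"], {"scale": 2.0}))  # ('MyOp', 'x', 2.0)
try:
    run_symbolic(my_op_misnamed, None, ["x"], {"scale": 2.0})
except TypeError as err:
    print("dispatch failed:", err)
```

The misnamed version contains no export logic at all, yet it still fails. Keyword binding rejects the call before the function body ever runs.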

### A note on Tensor types

In general, we should avoid depending on the type of Tensor Values contained
within the trace graph. However, this is sometimes unavoidable (due to ONNX
spec requirements, etc.). The TensorType object has accessors for these properties that return the property if it is statically known and `nullopt` otherwise.

In general, we should prefer to rely on the least specific information possible.
For example, not relying on tensor properties at all is better than relying
on the number of dimensions, which in turn is better than relying on
concrete shapes. Doing so makes the export symbolics
more robust across different graphs.
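
The preference order can be illustrated with a stand-in for `TensorType` whose accessors return `None` (Python's counterpart of `nullopt`) when a property is not statically known. `StaticInfo` and the three `last_axis*` helpers are hypothetical. The least specific variant assumes that the target ONNX operator accepts negative axes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StaticInfo:
    """Hypothetical stand-in for TensorType; None means 'not statically known'."""

    rank: Optional[int] = None
    sizes: Optional[List[int]] = None

    def dim(self) -> Optional[int]:
        # The rank can be known even when concrete sizes are not.
        if self.rank is not None:
            return self.rank
        return len(self.sizes) if self.sizes is not None else None


def last_axis_from_sizes(t: StaticInfo) -> Optional[int]:
    # Most specific: requires concrete sizes.
    return len(t.sizes) - 1 if t.sizes is not None else None


def last_axis_from_rank(t: StaticInfo) -> Optional[int]:
    # Less specific: requires only the number of dimensions.
    rank = t.dim()
    return rank - 1 if rank is not None else None


def last_axis(t: StaticInfo) -> int:
    # Least specific: a negative axis needs no static information at all,
    # assuming the target ONNX operator accepts negative axes.
    return -1


rank_only, unknown = StaticInfo(rank=3), StaticInfo()
print(last_axis_from_sizes(rank_only))  # None
print(last_axis_from_rank(rank_only))   # 2
print(last_axis_from_rank(unknown))     # None
print(last_axis(unknown))               # -1
```

Each step down in specificity keeps working on more graphs. Only the variant that needs nothing succeeds when no static information is available.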

### Extra context for symbolic functions

The first argument of a symbolic function is always a `GraphContext` object.

`GraphContext` contains all methods defined in a `torch.Graph` object, as well as context
for the symbolic function.

In general, symbolic functions only require the inputs and attributes of
the original node. An example of a symbolic function that needs context is
`prim::Loop`, which needs access to the sub-block of the original node.
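
The following minimal sketch assumes that `GraphContext` behaves like a dataclass that relays unknown attribute lookups to the wrapped graph. `FakeGraph`, `FakeNode`, `GraphContextSketch`, and `loop_symbolic` are hypothetical stand-ins. The field names are illustrative rather than the exporter's exact ones.

```python
from dataclasses import dataclass, field
from typing import Any, List


class FakeGraph:
    """Hypothetical stand-in for torch.Graph exposing a single `op` method."""

    def __init__(self) -> None:
        self.created: List[str] = []

    def op(self, kind: str, *inputs: Any) -> str:
        self.created.append(kind)
        return f"{kind}({', '.join(map(str, inputs))})"


@dataclass
class FakeNode:
    kind: str
    blocks: List[str] = field(default_factory=list)


@dataclass
class GraphContextSketch:
    graph: FakeGraph
    opset: int
    original_node: FakeNode

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup fails: relay graph methods such as
        # `op` so symbolic functions can use the context like a graph.
        return getattr(self.graph, name)


def loop_symbolic(g: GraphContextSketch, *inputs: str) -> str:
    body = g.original_node.blocks[0]    # context: sub-block of the original node
    return g.op("Loop", *inputs, body)  # graph method, relayed via __getattr__


ctx = GraphContextSketch(FakeGraph(), opset=13, original_node=FakeNode("prim::Loop", ["body"]))
print(loop_symbolic(ctx, "trip_count", "cond"))  # Loop(trip_count, cond, body)
```

Delegation keeps existing symbolic functions that only call graph methods working unchanged. Functions like the `prim::Loop` one can still reach the extra context.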

### Export inplace

It would be better to export inplace annotations than to drop them, since they
carry useful information that can help the target of an ONNX export run more
efficiently. However, ONNX doesn't currently formalize inplace. Fortunately, it's sound to drop
inplace annotations, but we lose information this way.
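
A toy instruction list shows why dropping the annotation is sound, yet lossy. This sketch is hypothetical: `drop_inplace`, `run`, and the tuple IR are invented here. It also assumes values have no aliasing. Real tensors can share storage through views, which a real rewrite must account for.

```python
from typing import Callable, Dict, List, Tuple

# Hypothetical toy IR: (op, destination, inputs). An inplace op ends in "_"
# and writes its result back into its first input (destination == inputs[0]).
Instr = Tuple[str, str, List[str]]

IMPLS: Dict[str, Callable[[float, float], float]] = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}


def run(program: List[Instr], env: Dict[str, float]) -> Dict[str, float]:
    env = dict(env)
    for op, dst, args in program:
        env[dst] = IMPLS[op.rstrip("_")](*(env[a] for a in args))
    return env


def drop_inplace(program: List[Instr]) -> List[Instr]:
    # `x.add_(y)` becomes `x = add(x, y)`: same result, but the fact that
    # the update was inplace is no longer recorded anywhere.
    return [(op.rstrip("_"), dst, args) for op, dst, args in program]


program: List[Instr] = [("add_", "x", ["x", "y"]), ("mul", "z", ["x", "y"])]
env = {"x": 1.0, "y": 2.0}
print(run(program, env) == run(drop_inplace(program), env))  # True
print(drop_inplace(program))  # [('add', 'x', ['x', 'y']), ('mul', 'z', ['x', 'y'])]
```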

### Pointwise by scalar

What happens if you add a constant to a tensor (e.g., `x + 2`)? There are
several moving parts to implementing the ONNX translation in this case:

- By the time we get the scalar in a symbolic function here, it is no longer a
  Python int/float, but a PyTorch tensor with `numel == 1` (eventually, we want
  it to be a zero-dim tensor, but this change has not happened yet). However, the
  type of this scalar is *exactly* what the user wrote in Python, which may not
  match the tensor it is being added to. PyTorch will do implicit conversions on
  scalars; however, ONNX will not, so we must do the conversion ourselves. This
  is what `symbolic_helper._if_scalar_type_as()` and
  `_jit_pass_onnx_scalar_type_analysis` do.

- Dispatch to these functions takes advantage of an outrageous coincidence
  between the tensor and scalar argument names. When we add two tensors together,
  we get the dispatch:

  `add(*[self, other], **{"alpha": alpha})`

  When we add a tensor and a scalar, we get the dispatch:

  `add(*[self], **{"other": other, "alpha": alpha})`

  Because the argument name lines up with the name of the scalar attribute
  (when it exists), we can write a single function for both overloads.
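
Both points can be sketched together in plain Python. Everything here is hypothetical:

- `Value` stands in for a graph value with a known element type.
- `scalar_type_as` plays the role described for `symbolic_helper._if_scalar_type_as()`.
- `g` (the `GraphContext`) is unused.

The sketch only casts a scalar to the tensor's type. It does not model PyTorch's full type promotion. Under those rules an integer tensor plus `2.5` produces a floating-point result, so the sketch raises for that case instead.

```python
from typing import Any, Optional, Tuple, Union


class Value:
    """Hypothetical stand-in for a graph value with a known element type."""

    def __init__(self, name: str, dtype: str) -> None:
        self.name = name
        self.dtype = dtype  # "float" or "int" in this sketch

    def __repr__(self) -> str:
        return f"{self.name}:{self.dtype}"


Scalar = Union[int, float]


def scalar_type_as(scalar: Scalar, like: Value) -> Value:
    """Cast a Python scalar to the element type of `like`, since ONNX will not."""
    if like.dtype == "int" and isinstance(scalar, float):
        # PyTorch would promote the result to float here; not modeled.
        raise NotImplementedError("type promotion is out of scope for this sketch")
    converted = float(scalar) if like.dtype == "float" else int(scalar)
    return Value(repr(converted), like.dtype)


def add(g: Any, self: Value, other: Union[Value, Scalar], alpha: Optional[Scalar] = None) -> Tuple[Any, ...]:
    # One function serves both overloads: `other` arrives positionally for
    # add(Tensor, Tensor) and by keyword for add(Tensor, Scalar).
    if not isinstance(other, Value):
        other = scalar_type_as(other, self)
    if self.dtype != other.dtype:
        raise TypeError("ONNX Add requires both inputs to have the same type")
    # `alpha` is only carried along; a real symbolic would scale `other` by it.
    return ("Add", self, other, alpha)


x, y = Value("x", "float"), Value("y", "float")
print(add(None, *[x, y], **{"alpha": 1}))          # ('Add', x:float, y:float, 1)
print(add(None, *[x], **{"other": 2, "alpha": 1}))  # ('Add', x:float, 2.0:float, 1)
```

The integer literal `2` reaches the function through the keyword `other` and is cast to the tensor's `float` type before the `Add` is built. This matches the ONNX requirement that both `Add` inputs share one type.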