# Fusion Pattern Format
Patterns are built from float module types, functional operators, and PyTorch operators. They are written in reverse order, so the operator that produces the final output comes first:
```
operator = module_type | functional | torch op | native op | MatchAllNode
Pattern = (operator, Pattern, Pattern, ...) | operator
```
where the first item of a Pattern is the operator we want to match, and the remaining items are the patterns for that operator's arguments, in argument order.
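Because patterns are written output-first, a simple linear chain such as conv, then batch norm, then ReLU nests from the outside in. The sketch below builds that nested tuple from a forward-order list. `chain_to_pattern` is a hypothetical helper and is not part of PyTorch. Strings stand in for `nn.Conv2d`, `nn.BatchNorm2d`, and `nn.ReLU` so the sketch runs without torch. It also assumes that each operator's first argument is the output of the previous operator.

```python
from typing import Any, Sequence


def chain_to_pattern(ops: Sequence[Any]) -> Any:
    """Hypothetical helper: turn a forward-order chain (ops[0] runs first)
    into the nested, reverse-order pattern format."""
    if not ops:
        raise ValueError("chain must contain at least one operator")
    pattern = ops[0]  # the earliest op becomes the innermost bare operator
    for op in ops[1:]:
        # The later op is the head; the pattern so far is its first argument.
        pattern = (op, pattern)
    return pattern


# Strings stand in for nn.Conv2d, nn.BatchNorm2d and nn.ReLU.
print(chain_to_pattern(["Conv2d", "BatchNorm2d", "ReLU"]))
# ('ReLU', ('BatchNorm2d', 'Conv2d'))
```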
For example, the pattern `(nn.ReLU, (operator.add, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))` matches the following graph:
```
tensor_1            tensor_2
 |                    |
 *(MatchAllNode)  nn.Conv2d
 |                    |
 |             nn.BatchNorm2d
 \                  /
  -- operator.add --
         |
      nn.ReLU
```
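A minimal sketch of how a matcher could walk this format, under toy assumptions. `Node` and `is_match` are hypothetical names, and `Node.args` plays the role of `torch.fx.Node.args`. `MatchAllNode` here is a local stand-in for the wildcard, and strings replace the real module types and `operator.add`. This is not PyTorch's own matcher, which operates on real `torch.fx` graphs.

```python
from dataclasses import dataclass
from typing import Any, Tuple


class MatchAllNode:
    """Local stand-in for the wildcard operator: matches any node."""


@dataclass(eq=False)
class Node:
    """Hypothetical minimal graph node. `op` is what the node computes and
    `args` are its inputs, playing the role of torch.fx.Node.args."""
    op: Any
    args: Tuple[Any, ...] = ()


def is_match(pattern: Any, node: Any) -> bool:
    """Return True if `node` and the nodes feeding it match `pattern`."""
    if pattern is MatchAllNode:
        return True  # wildcard: accept any node without looking further
    # A bare operator is a pattern with no argument sub-patterns.
    if isinstance(pattern, tuple):
        op, *arg_patterns = pattern
    else:
        op, arg_patterns = pattern, []
    if not isinstance(node, Node) or node.op != op:
        return False
    if len(arg_patterns) > len(node.args):
        return False
    # The i-th sub-pattern must match the i-th argument of the node.
    return all(is_match(p, a) for p, a in zip(arg_patterns, node.args))


# Toy version of the example graph; strings stand in for real operators.
tensor_1, tensor_2 = Node("placeholder"), Node("placeholder")
conv = Node("Conv2d", (tensor_2,))
bn = Node("BatchNorm2d", (conv,))
add = Node("operator.add", (tensor_1, bn))
relu = Node("ReLU", (add,))

pattern = ("ReLU", ("operator.add", MatchAllNode, ("BatchNorm2d", "Conv2d")))
print(is_match(pattern, relu))  # True
print(is_match(pattern, add))   # False: the anchor must be the ReLU node
```

Argument order matters in this sketch. Swapping the two sub-patterns of `operator.add` makes the match fail, because `("BatchNorm2d", "Conv2d")` would then be compared against `tensor_1`.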
19```
20
21we’ll match the last node as the anchor point of the match, and we can retrieve the whole graph by tracing back from the node, e.g. in the example above, we matched nn.ReLU node, then node.args[0] is the operator.add node.
22