xref: /aosp_15_r20/external/executorch/backends/example/example_operators/flatten.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from dataclasses import dataclass
8
9import torch
10from executorch.backends.example.example_operators.op_base import OpBase
11from executorch.backends.example.example_operators.utils import (
12    _annotate_nodes,
13    _nodes_are_annotated,
14)
15
16
def _annotate_flatten(partitions, quant_config):
    """
    Annotate a flatten partition for quantization.

    The first output node of ``partitions[0]`` is treated as the flatten node,
    and that node's first positional argument is treated as the tensor being
    flattened.

    If ``_nodes_are_annotated`` reports that the flatten node is already
    annotated, the function returns without touching it. Otherwise it makes
    two ``_annotate_nodes`` calls:

    * The ``(flatten_node, flatten_input)`` pair with
      ``quant_config.input_quant_spec`` and ``input_node=True``. Presumably
      this annotates the flatten node's input edge; confirm against
      ``_annotate_nodes``.
    * ``flatten_node`` alone with ``quant_config.output_quant_spec``.
      Presumably this annotates the flatten node's output.

    Args:
        partitions: Sequence of matched partitions. Only ``partitions[0]`` is
            read. Presumably these are the partitions matched for the
            ``(torch.flatten,)`` pattern that ``FlattenNode`` registers with
            this handler (TODO: confirm against the caller).
        quant_config: Object exposing ``input_quant_spec`` and
            ``output_quant_spec`` attributes.

    Returns:
        None. All effects happen through ``_annotate_nodes``.
    """
    flatten_node = partitions[0].output_nodes[0]
    # The tensor being flattened is the first positional argument.
    flatten_input = flatten_node.args[0]

    # Leave an existing annotation alone so the same node is not annotated twice.
    if _nodes_are_annotated([flatten_node]):
        return

    _annotate_nodes(
        [(flatten_node, flatten_input)], quant_config.input_quant_spec, input_node=True
    )
    _annotate_nodes([(flatten_node,)], quant_config.output_quant_spec)
32
33
@dataclass
class FlattenNode(OpBase):
    """
    Operator entry for ``torch.flatten``.

    Passes ``OpBase`` a one-element pattern tuple containing ``torch.flatten``,
    and uses ``_annotate_flatten`` as the annotation handler.
    """

    def __init__(self):
        # A one-element tuple containing only torch.flatten.
        flatten_pattern = (torch.flatten,)
        super().__init__(
            annotate_handle=_annotate_flatten,
            pattern=flatten_pattern,
        )
41