xref: /aosp_15_r20/external/executorch/backends/arm/_passes/meandim_to_averagepool_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from typing import Any, cast, Dict, Tuple
10
11import torch.fx
12
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
15
16Argument = Any
17
18
class ConvertMeanDimToAveragePool(ExportPass):
    """
    Replace a mean operation over the last two dimensions of a 4-D input
    (with keep_dim = True) with an average pool operation.

    The pooling kernel covers the full extent of the two reduced dimensions,
    so the pooled output has size 1 along both of them, which is the shape
    the mean produces when the reduced dimensions are kept.

    Any mean that does not match this pattern is forwarded unchanged.
    """

    @staticmethod
    def _reduces_last_two_dims(dim: Argument, rank: int) -> bool:
        """
        Return True if ``dim`` names exactly the last two axes of a tensor
        with ``rank`` dimensions.

        Order and sign do not matter: for rank 4, ``[-1, -2]``, ``[-2, -1]``
        and ``[2, 3]`` all describe the same reduction.
        """
        if not isinstance(dim, (list, tuple)) or len(dim) != 2:
            return False
        if not all(isinstance(d, int) and -rank <= d < rank for d in dim):
            return False
        # Normalising to non-negative indices and comparing as a set also
        # rejects duplicated axes such as [-1, 3].
        return {d % rank for d in dim} == {rank - 2, rank - 1}

    def call_operator(
        self,
        op: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Rewrite a matching ``mean.dim`` call into ``avg_pool2d``.

        Args:
            op: Target of the node being visited.
            args: Positional arguments of the node; for ``mean.dim`` these are
                the input, the dims to reduce and optionally keep_dim.
            kwargs: Keyword arguments of the node.
            meta: Metadata of the node, passed through to the emitted call.

        Returns:
            The value produced by the emitted operator: an ``avg_pool2d`` call
            when the pattern matches, otherwise the original call unchanged.
        """
        if op != exir_ops.edge.aten.mean.dim:
            return super().call_operator(op, args, kwargs, meta)

        input_value = cast(ProxyValue, args[0])
        dim = args[1] if len(args) > 1 else kwargs.get("dim")
        keep_dim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)

        # avg_pool2d keeps the reduced dimensions and has no dtype argument, so
        # a mean that drops dimensions or requests an explicit output dtype
        # cannot be expressed by it and is left untouched.
        if keep_dim is not True or kwargs.get("dtype") is not None:
            return super().call_operator(op, args, kwargs, meta)

        shape = input_value.to_tensor().size()

        # The kernel below is read from axes 2 and 3, i.e. the input is taken
        # to be laid out as (N, C, H, W). Only rewrite when the input really
        # is 4-D and the mean reduces exactly those two axes.
        if len(shape) != 4 or not self._reduces_last_two_dims(dim, len(shape)):
            return super().call_operator(op, args, kwargs, meta)

        kernel_size = [shape[2], shape[3]]
        stride = [1, 1]
        return super().call_operator(
            exir_ops.edge.aten.avg_pool2d.default,
            (input_value, kernel_size, stride),
            {},
            meta,
        )
55