# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import Any, cast, Dict, Tuple

import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue

Argument = Any


class ConvertMeanDimToAveragePool(ExportPass):
    """
    Replace a mean over the two trailing dimensions of a 4-D tensor, with
    keep_dim = True, by an equivalent average pool operation.

    ``mean(x, dim=[-1, -2], keepdim=True)`` on a 4-D input of shape
    (d0, d1, d2, d3) produces shape (d0, d1, 1, 1), which is the same result as
    ``avg_pool2d(x, kernel_size=[d2, d3], stride=[1, 1])`` (a single window
    covering the whole trailing plane, no padding). The two trailing dims may
    be given in either order and as negative or non-negative indices. Every
    other mean operation (different dims, keep_dim False, explicit dtype, or
    an input that is not 4-D) is passed through unchanged.
    """

    # avg_pool2d is only substituted for 4-D inputs; for other ranks the
    # trailing-plane kernel below would be computed from the wrong dims.
    _INPUT_RANK = 4
    # Normalized (non-negative) dims of the two trailing axes of a 4-D tensor.
    _TRAILING_DIMS = frozenset({2, 3})

    def call_operator(
        self,
        op: torch.fx.node.Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op != exir_ops.edge.aten.mean.dim:
            return super().call_operator(op, args, kwargs, meta)

        input_value = cast(ProxyValue, args[0])
        # dim / keepdim are normally positional, but accept them as kwargs too.
        dim = args[1] if len(args) > 1 else kwargs.get("dim")
        keep_dim = (
            cast(bool, args[2]) if len(args) > 2 else kwargs.get("keepdim", False)
        )

        # avg_pool2d keeps the reduced dims (as size 1) and has no dtype
        # argument, so only keepdim=True means without an explicit output
        # dtype can be rewritten without changing the result.
        if keep_dim is not True or kwargs.get("dtype") is not None:
            return super().call_operator(op, args, kwargs, meta)

        # Cheap structural check before materializing the input tensor.
        if not isinstance(dim, (list, tuple)) or len(dim) != 2:
            return super().call_operator(op, args, kwargs, meta)

        shape = input_value.to_tensor().size()
        if not self._reduces_trailing_plane(dim, len(shape)):
            return super().call_operator(op, args, kwargs, meta)

        # One pooling window covering the whole trailing plane; the stride is
        # irrelevant because the output spatial size is 1x1.
        kernel_size = [shape[2], shape[3]]
        stride = [1, 1]
        return super().call_operator(
            exir_ops.edge.aten.avg_pool2d.default,
            (input_value, kernel_size, stride),
            {},
            meta,
        )

    @classmethod
    def _reduces_trailing_plane(cls, dim: Any, rank: int) -> bool:
        """Return True if ``dim`` selects exactly the two trailing axes of a 4-D tensor.

        Args:
            dim: The ``dim`` argument of ``mean.dim`` (a list/tuple of ints).
            rank: Number of dimensions of the mean's input tensor.
        """
        if rank != cls._INPUT_RANK:
            return False
        if not all(isinstance(d, int) for d in dim):
            return False
        normalized = {d + rank if d < 0 else d for d in dim}
        # Duplicates (e.g. [3, -1]) collapse to one element and are rejected.
        return len(normalized) == len(dim) and normalized == cls._TRAILING_DIMS