# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass


def _reduced_numel(shape, dims):
    """Return how many elements a mean over ``dims`` of ``shape`` averages.

    ``dims`` may be ``None`` or empty; both mean "reduce every dimension",
    matching how the forwarded ``aten.sum.dim_IntList`` call treats them, so
    the count is then the product of all sizes (1 for a 0-d shape).
    Negative dims index from the end, as with Python sequence indexing.

    The product starts from ``1.0`` so the result is a float, usable as the
    fill value of the divisor tensor.
    """
    if not dims:
        dims = range(len(shape))
    count = 1.0
    for dim in dims:
        count = count * shape[dim]
    return count


class MeanToSumDiv(ExportPass):
    """Rewrite ``aten.mean.dim`` as ``aten.sum.dim_IntList`` followed by a divide.

    The sum is computed with the mean's original args/kwargs, then divided by
    a one-element tensor filled with the number of reduced elements. All other
    operators are passed through unchanged.
    """

    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.mean.dim:
            return super().call_operator(op, args, kwargs, meta)

        # mean.dim and sum.dim_IntList share the schema
        # (self, dim, keepdim, *, dtype), so arguments are forwarded as-is.
        sum_res = super().call_operator(
            exir_ops.edge.aten.sum.dim_IntList, args, kwargs, meta
        )

        # args[0] is the input tensor; its recorded "val" carries shape/dtype.
        input_val = args[0].node.meta["val"]
        # dim is optional: it may be absent, None, or empty (=> all dims).
        dims_to_reduce = args[1] if len(args) > 1 else kwargs.get("dim")
        size = _reduced_numel(input_val.shape, dims_to_reduce)

        # Build the divisor in the requested output dtype when one is given,
        # otherwise in the input dtype (the original behavior).
        dtype = kwargs.get("dtype")
        if dtype is None:
            dtype = input_val.dtype

        size_tensor = super().call_operator(
            exir_ops.edge.aten.full.default,
            ([1], size),
            {"dtype": dtype},
            meta,
        )

        return super().call_operator(
            exir_ops.edge.aten.div.Tensor, (sum_res, size_tensor), {}, meta
        )