# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Tuple

import torch
from executorch.exir.pass_base import ExportPass
from torch._export.pass_base import Argument
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch._export.pass_infra.proxy_value import ProxyValue


class ConvertBinaryOpsWithScalar(ExportPass):
    """
    Retarget scalar-overload binary ops to their tensor overloads.

    ``torch.ops.aten.<op>.Scalar`` does not generate a placeholder node for
    its scalar operand after ``to_edge``. This pass rewrites each call to one
    of the ops listed in ``binary_ops_with_scalar`` so that it uses the
    matching ``.Tensor`` overload instead. Arguments, keyword arguments and
    metadata are passed through untouched. Any operator not in the table is
    forwarded unchanged.
    """

    # Maps each ``.Scalar`` overload to the ``.Tensor`` overload that replaces it.
    binary_ops_with_scalar = {
        torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
        torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
        torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
        torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
    }

    def __init__(self):
        super().__init__()

    def call_operator(
        self,
        op,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        """
        Emit ``op`` via ``ExportPass.call_operator``, swapping in the tensor
        overload when ``op`` is one of the known scalar overloads.

        Args:
            op: The operator being called.
            args: Positional arguments, forwarded as-is.
            kwargs: Keyword arguments, forwarded as-is.
            meta: Node metadata, forwarded as-is.

        Returns:
            Whatever ``ExportPass.call_operator`` returns for the
            (possibly retargeted) operator.
        """
        target = op
        if op in self.binary_ops_with_scalar:
            target = self.binary_ops_with_scalar[op]
        return super().call_operator(target, args, kwargs, meta)