# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from executorch.backends.example.example_operators.op_base import OpBase
from executorch.backends.example.example_operators.utils import (
    _annotate_nodes,
    _nodes_are_annotated,
)


def _annotate_add(partitions, quant_config):
    """
    Attach quantization annotations to a matched ``add`` node.

    This is what the graph of a simple add op looks like:
    add_tensor = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None

    Both operands of ``add`` are activations (``add`` has no weight), so both
    input edges are annotated with ``quant_config.input_quant_spec``. The
    node's output is annotated with ``quant_config.output_quant_spec``.

    Args:
        partitions: Matched partitions for the ``add`` pattern. Only the first
            output node of the first partition is used (presumably the
            ``aten.add.Tensor`` node -- verify against the caller).
        quant_config: Object exposing ``input_quant_spec`` and
            ``output_quant_spec``.

    Returns:
        None. Returns early, without changing anything, if the add node is
        already annotated.
    """
    add_node = partitions[0].output_nodes[0]

    # Skip nodes that have already been annotated so they are not overwritten.
    if _nodes_are_annotated([add_node]):
        return

    # Both operands are activations. Annotating the second one with the
    # weight spec (as before) would quantize the two inputs inconsistently.
    for add_input in (add_node.args[0], add_node.args[1]):
        _annotate_nodes(
            [(add_node, add_input)], quant_config.input_quant_spec, input_node=True
        )
    _annotate_nodes([(add_node,)], quant_config.output_quant_spec)


@dataclass
class AddNode(OpBase):
    """Example-backend operator description for ``torch.add``.

    Registers the ``torch.add`` pattern together with ``_annotate_add`` as the
    handler that attaches quantization annotations to matched nodes.
    """

    # An ``__init__`` defined in the class body is kept by ``@dataclass``
    # (it only generates one when the class does not define it).
    def __init__(self):
        super().__init__(
            pattern=(torch.add,),
            annotate_handle=_annotate_add,
        )