# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from executorch.backends.example.example_operators.op_base import OpBase
from executorch.backends.example.example_operators.utils import (
    _annotate_nodes,
    _nodes_are_annotated,
)


def _annotate_conv2d(partitions, quant_config):
    """Attach quantization annotations to a matched conv2d partition.

    For reference, this is what the graph of a simple conv op looks like::

        l__self___conv_weight = self.L__self___conv_weight
        l__self___conv_bias = self.L__self___conv_bias
        convolution_default = torch.ops.aten.convolution.default(
            arg2_1, l__self___conv_weight, l__self___conv_bias,
            [1, 1], [1, 1], [1, 1], False, [0, 0], 1,
        ); arg2_1 = l__self___conv_weight = l__self___conv_bias = None

    Args:
        partitions: Matched partitions; the conv node is taken to be
            ``partitions[0].output_nodes[0]``.
        quant_config: Object exposing ``input_quant_spec``,
            ``weight_quant_spec`` and ``output_quant_spec``, which are
            forwarded to ``_annotate_nodes``.

    Returns:
        None. If ``_nodes_are_annotated`` reports the conv node as already
        annotated, the function returns without calling ``_annotate_nodes``.
    """
    conv_node = partitions[0].output_nodes[0]

    # Skip nodes that an earlier pass has already annotated; check this
    # before reading any of the node's arguments.
    if _nodes_are_annotated([conv_node]):
        return

    # Per the graph above, args[0] is the activation input and args[1] is the
    # weight of aten.convolution. NOTE(review): the bias (args[2]) is not
    # annotated here; confirm whether that is the intended behavior.
    input_node = conv_node.args[0]
    weight_node = conv_node.args[1]

    _annotate_nodes(
        [(conv_node, input_node)], quant_config.input_quant_spec, input_node=True
    )
    _annotate_nodes(
        [(conv_node, weight_node)], quant_config.weight_quant_spec, input_node=True
    )
    _annotate_nodes([(conv_node,)], quant_config.output_quant_spec)


# NOTE(review): ``@dataclass`` keeps the explicit ``__init__`` below, but it
# still generates ``__eq__``/``__repr__``. Because it runs with ``eq=True`` and
# without ``frozen=True``, it also sets ``__hash__`` to None, which makes
# instances unhashable. Confirm that is intended.
@dataclass
class Conv2DNode(OpBase):
    """Example-backend operator descriptor for ``torch.nn.Conv2d``.

    Passes ``OpBase`` the pattern ``(torch.nn.Conv2d,)`` and
    ``_annotate_conv2d`` as its annotation handle, and sets
    ``permuate_memory_format=True``. Presumably this asks the backend to
    permute the memory format; confirm in ``OpBase``. The keyword's spelling
    is dictated by ``OpBase``.
    """

    def __init__(self):
        super().__init__(
            pattern=(torch.nn.Conv2d,),
            annotate_handle=_annotate_conv2d,
            permuate_memory_format=True,
        )