# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from executorch.backends.example.example_operators.op_base import OpBase
from executorch.backends.example.example_operators.utils import (
    _annotate_nodes,
    _nodes_are_annotated,
)


def _annotate_flatten(partitions, quant_config):
    """
    Annotate a flatten partition for quantization. The partition's single
    output node is the flatten call and its first argument is the tensor
    being flattened; both are tagged with the specs from quant_config.
    """
    flatten_node = partitions[0].output_nodes[0]
    flatten_input = flatten_node.args[0]

    # Skip partitions that have already been annotated.
    if _nodes_are_annotated([flatten_node]):
        return

    # Annotate the (flatten, input) edge with the input quant spec and the
    # flatten node itself with the output quant spec.
    _annotate_nodes(
        [(flatten_node, flatten_input)], quant_config.input_quant_spec, input_node=True
    )
    _annotate_nodes([(flatten_node,)], quant_config.output_quant_spec)


@dataclass
class FlattenNode(OpBase):
    def __init__(self):
        super().__init__(
            pattern=(torch.flatten,),
            annotate_handle=_annotate_flatten,
        )
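
# Usage sketch (an assumption, not part of the upstream module): the example
# quantizer is expected to match `pattern` against a traced module and then
# invoke `annotate_handle` on the resulting partitions. The `partitions` and
# `quant_config` names below are hypothetical placeholders for the objects
# produced by the partitioner and the quantization config.
#
#   flatten_annotator = FlattenNode()
#   flatten_annotator.annotate_handle(partitions, quant_config)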