# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import (
    GraphBuilder,
    single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from later.unittest import TestCase


def _assert_single_im2row(test_case: TestCase, gm: "torch.fx.GraphModule") -> None:
    """Assert that ``gm`` contains exactly one operator node and that it is im2row.

    Args:
        test_case: The running test, used for its assertion methods.
        gm: The graph module to inspect.

    Placeholder and output nodes are ignored. Only ``call_function`` nodes are
    counted, so the check does not depend on how many graph inputs exist.
    """
    # Count operator nodes explicitly. Wrapping ``gm.graph.nodes`` in a list
    # literal and taking its length would always yield 1, checking nothing.
    op_nodes = [node for node in gm.graph.nodes if node.op == "call_function"]
    test_case.assertEqual(len(op_nodes), 1)
    test_case.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)


class TestGraphBuilder(TestCase):
    """Tests for building graphs node-by-node with ``GraphBuilder``."""

    def test_graph_with_single_im2row(self) -> None:
        """A graph with one im2row node survives ExportPass and keeps that node."""
        # Create a graph with a single im2row node.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        pad_value = builder.placeholder("pad", torch.randn(1))
        channels_last = False
        im2row = builder.call_operator(
            exir_ops.edge.cadence.im2row.default,
            # pyre-ignore
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        builder.output([im2row])
        gm = builder.get_graph_module()
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module

        # Check graph has a single im2row node.
        _assert_single_im2row(self, gm)


class TestSingleOpBuilderUtility(TestCase):
    """Tests for the ``single_op_builder`` convenience helper."""

    def test_graph_with_single_im2row(self) -> None:
        """``single_op_builder`` yields a valid graph with exactly one im2row node."""
        # Create a graph with a single im2row node.
        x = torch.randn(1, 3, 224, 224)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            (x, pad_value),
            exir_ops.edge.cadence.im2row.default,
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module

        # Check graph has a single im2row node.
        _assert_single_im2row(self, gm)