xref: /aosp_15_r20/external/executorch/backends/cadence/aot/tests/test_graph_builder.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3
4import executorch.backends.cadence.aot.ops_registrations  # noqa
5import torch
6from executorch.backends.cadence.aot.graph_builder import (
7    GraphBuilder,
8    single_op_builder,
9)
10from executorch.backends.cadence.aot.pass_utils import count_node
11from executorch.exir.dialects._ops import ops as exir_ops
12from executorch.exir.pass_base import ExportPass
13from later.unittest import TestCase
14
15
class TestGraphBuilder(TestCase):
    """Tests that build graphs node-by-node with ``GraphBuilder``."""

    def test_graph_with_single_im2row(self) -> None:
        """Build a graph holding one cadence ``im2row`` op and validate it.

        The graph has two placeholders (``x`` and ``pad``), one ``im2row``
        call and an output node. It is re-run through ``ExportPass`` to
        confirm it is well-formed.
        """
        # Create a graph with a single im2row node.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
        pad_value = builder.placeholder("pad", torch.randn(1))
        channels_last = False
        im2row = builder.call_operator(
            exir_ops.edge.cadence.im2row.default,
            # pyre-ignore
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        builder.output([im2row])
        gm = builder.get_graph_module()
        # Check if graph module is valid by running ExportPass on it.
        gm = ExportPass().call(gm).graph_module

        # Count only operator (call_function) nodes. The graph also contains
        # placeholder and output nodes, which must not be counted here.
        # Wrapping the node collection in a one-element list and taking its
        # length would always yield 1 and never fail.
        call_function_nodes = [
            node for node in gm.graph.nodes if node.op == "call_function"
        ]
        self.assertEqual(len(call_function_nodes), 1)
        # Also check that this single operator is the im2row op.
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
44
45
class TestSingleOpBuilderUtility(TestCase):
    """Tests for the ``single_op_builder`` convenience helper."""

    def test_graph_with_single_im2row(self) -> None:
        """Build a one-op cadence ``im2row`` graph via ``single_op_builder``.

        The tensors ``x`` and ``pad_value`` are passed to the helper as the
        graph inputs, and also appear inside the operator's argument tuple.
        The resulting module is re-run through ``ExportPass`` to confirm it
        is well-formed.
        """
        # Create a graph with a single im2row node.
        x = torch.randn(1, 3, 224, 224)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            (x, pad_value),
            exir_ops.edge.cadence.im2row.default,
            (
                x,
                (2, 2),
                (1, 1),
                (0, 0),
                (1, 1),
                pad_value,
                channels_last,
            ),
        )
        # Check if graph module is valid by running ExportPass on it.
        gm = ExportPass().call(gm).graph_module

        # Count only operator (call_function) nodes, excluding the placeholder
        # and output nodes. Taking the length of a one-element list that wraps
        # the node collection would always be 1 and would never fail.
        call_function_nodes = [
            node for node in gm.graph.nodes if node.op == "call_function"
        ]
        self.assertEqual(len(call_function_nodes), 1)
        # Also check that this single operator is the im2row op.
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
71