xref: /aosp_15_r20/external/executorch/exir/backend/test/test_passes.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import unittest
8
9import torch
10from executorch import exir
11from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
12    duplicate_constant_node,
13)
14from torch._export.utils import is_buffer
15from torch.export import export_for_training
16from torch.testing import FileCheck
17
18
class TestPasses(unittest.TestCase):
    """Unit tests for backend graph passes."""

    def test_duplicate_constant_node_pass(self):
        """Duplicate a buffer that is consumed by two separate ops.

        The module reads one registered buffer in both an add and a sub.
        After lowering to the edge dialect, ``duplicate_constant_node`` is run
        on that buffer's placeholder. The test expects the pass to report one
        copied node and the graph code to contain ``b_const_copy_0``.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("const", torch.ones(2, 2))

            def forward(self, x):
                # Both ops read the same buffer, so its placeholder has two users.
                y = x + self.const
                z = x - self.const
                return y, z

        example_inputs = (torch.ones(2, 2),)
        model = export_for_training(ReuseConstData(), example_inputs).module()
        edge = exir.to_edge(torch.export.export(model, example_inputs))
        # Fetch the edge program once. The pass below modifies this program,
        # and the final check inspects the same program's generated code.
        exported_program = edge.exported_program()

        const_nodes = [
            node.name
            for node in exported_program.graph.nodes
            if node.op == "placeholder" and is_buffer(exported_program, node)
        ]
        # Fail with a clear message (rather than an IndexError below) if the
        # registered buffer is not recognized as a buffer placeholder.
        self.assertEqual(
            len(const_nodes),
            1,
            f"expected exactly one buffer placeholder, found {const_nodes}",
        )

        copied_nodes = duplicate_constant_node(exported_program, const_nodes[0])
        self.assertEqual(len(copied_nodes), 1)

        # Check that the new constant node is in the graph
        FileCheck().check("b_const_copy_0").run(exported_program.graph_module.code)
48