# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch import exir
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from torch._export.utils import is_buffer
from torch.export import export_for_training
from torch.testing import FileCheck


class TestPasses(unittest.TestCase):
    def test_duplicate_constant_node_pass(self):
        """Duplicating a buffer read by two ops produces exactly one copy.

        ``ReuseConstData`` reads the same registered buffer in two ops (an
        add and a sub). After lowering to edge dialect, running
        ``duplicate_constant_node`` on that buffer's placeholder should
        report one copied node, and ``b_const_copy_0`` should then appear
        in the generated graph code.
        """

        class ReuseConstData(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("const", torch.ones(2, 2))

            def forward(self, x):
                y = x + self.const
                z = x - self.const
                return y, z

        example_inputs = (torch.ones(2, 2),)
        model = export_for_training(ReuseConstData(), example_inputs).module()
        edge = exir.to_edge(torch.export.export(model, example_inputs))

        # Look the program up once. The pass below edits it and its result is
        # inspected afterwards without re-exporting, so the same program is
        # used for the scan, the pass, and the final check.
        ep = edge.exported_program()

        const_nodes = [
            node.name
            for node in ep.graph.nodes
            if node.op == "placeholder" and is_buffer(ep, node)
        ]
        # The module registers exactly one buffer. Fail with a readable message
        # rather than an IndexError if lowering stops exposing it as a buffer.
        self.assertEqual(
            len(const_nodes), 1, f"unexpected buffer placeholders: {const_nodes}"
        )

        copied_nodes = duplicate_constant_node(ep, const_nodes[0])
        self.assertEqual(len(copied_nodes), 1)

        # Check that the new constant node is in the graph
        FileCheck().check("b_const_copy_0").run(ep.graph_module.code)