# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import pandas as pd

import torch
from executorch.devtools.backend_debug import DelegationBreakdown, get_delegation_info
from executorch.exir import to_edge
from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
from pandas.testing import assert_frame_equal


class TestUtils(unittest.TestCase):
    """Tests for the delegation-inspection helpers in ``backend_debug``."""

    def test_get_delegation_info(self):
        """Lower a small model with ``AddMulPartitionerDemo`` and verify the
        delegation counts, the per-operator breakdown, the text summary, and
        the per-operator DataFrame reported by ``get_delegation_info``."""

        class _MmAddSubModel(torch.nn.Module):
            # Two ``mm -> add`` chains joined by a ``sub``. The assertions
            # below expect the mm/add ops to be delegated and the sub not.
            def forward(self, a, x, b):
                bridged = torch.mm(a, x) + b - a
                return torch.mm(bridged, x) + b

        example_inputs = tuple(torch.randn(2, 2) for _ in range(3))
        lowered = to_edge(
            torch.export.export(_MmAddSubModel(), example_inputs)
        ).to_backend(AddMulPartitionerDemo())
        info = get_delegation_info(lowered.exported_program().graph_module)

        self.assertEqual(info.num_delegated_subgraphs, 2)
        self.assertEqual(info.num_delegated_nodes, 4)
        self.assertEqual(info.num_non_delegated_nodes, 3)

        # (op_type, delegated count, non-delegated count), listed in the row
        # order the operator DataFrame is expected to use. The two `getitem`
        # nodes presumably unpack the delegate call outputs — they are
        # expected to stay outside the delegates.
        per_op_counts = (
            ("aten_add_tensor", 2, 0),
            ("aten_mm_default", 2, 0),
            ("aten_sub_tensor", 0, 1),
            ("getitem", 0, 2),
        )

        self.assertEqual(
            info.delegation_by_operator,
            {
                op_name: DelegationBreakdown(
                    op_type=op_name, delegated=in_delegate, non_delegated=outside
                )
                for op_name, in_delegate, outside in per_op_counts
            },
        )

        self.assertIn("Total delegated subgraphs", info.get_summary())

        actual_df = info.get_operator_delegation_dataframe()
        op_types, delegated_counts, non_delegated_counts = zip(*per_op_counts)
        # The DataFrame ends with a "Total" row summing each count column.
        expected_df = pd.DataFrame(
            {
                "op_type": [*op_types, "Total"],
                "occurrences_in_delegated_graphs": [
                    *delegated_counts,
                    sum(delegated_counts),
                ],
                "occurrences_in_non_delegated_graphs": [
                    *non_delegated_counts,
                    sum(non_delegated_counts),
                ],
            }
        )
        assert_frame_equal(expected_df, actual_df)