# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import re
import unittest

import torch
import torch.fx

from executorch.exir.common import extract_out_arguments, get_schema_for_operators
from executorch.exir.print_program import add_cursor_to_graph


class TestExirCommon(unittest.TestCase):
    """Tests for helpers in ``executorch.exir.common`` and ``executorch.exir.print_program``."""

    def test_get_schema_for_operators(self) -> None:
        """Each schema string returned for ``op_list`` has the form ``name(args) -> ...``."""
        op_list = [
            "torch.ops._caffe2.RoIAlign.default",
            "torch.ops.aten.add.Tensor",
            "torch.ops.aten.batch_norm.default",
            "torch.ops.aten.cat.default",
            "torch.ops.aten.clamp.default",
        ]

        schemas = get_schema_for_operators(op_list)
        # Guard against a vacuous pass: the per-schema loop below asserts
        # nothing at all if the returned mapping is empty.
        self.assertTrue(schemas, msg="get_schema_for_operators returned no schemas")

        # Anchored at the start (pattern.match): a name, a parenthesized
        # argument list, then the " -> " return arrow.
        pat = re.compile(r"[^\(]+\([^\)]+\) -> ")
        for op_name, schema in schemas.items():
            # subTest reports which operator failed instead of stopping at the
            # first mismatch with no context.
            with self.subTest(op=op_name):
                self.assertIsNotNone(
                    pat.match(schema), msg=f"unexpected schema format: {schema!r}"
                )

    def test_get_out_args(self) -> None:
        """``extract_out_arguments`` names ``out`` for both a Tensor and a Tensor[] out arg."""
        schema1 = torch._C.parse_schema(
            "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
        )
        schema2 = torch._C.parse_schema(
            "split_copy.Tensor_out(Tensor self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()"
        )

        # Single-tensor out argument.
        out_args_1 = extract_out_arguments(schema1, {"out": torch.ones(5)})
        # Tensor-list out argument.
        out_args_2 = extract_out_arguments(
            schema2, {"out": [torch.ones(5), torch.ones(5)]}
        )

        # The result unpacks as a (name, value) pair; only the name is checked.
        out_arg_name_1, _ = out_args_1
        self.assertEqual(out_arg_name_1, "out")

        out_arg_name_2, _ = out_args_2
        self.assertEqual(out_arg_name_2, "out")

    def test_add_cursor(self) -> None:
        """``add_cursor_to_graph`` marks the selected node's line with ``-->``."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x + self.param).clamp(min=0.0, max=1.0)

        # torch.fx is already imported at module level; no local import needed.
        symbolic_traced = torch.fx.symbolic_trace(MyModule())
        graph = symbolic_traced.graph

        # Graph we are testing:
        # graph():
        #     %x : [#users=1] = placeholder[target=x]
        #     %param : [#users=1] = get_attr[target=param]
        #     %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
        # --> %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
        #     %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
        #     return clamp
        #
        # Node index 3 is %linear. It is printed on output line 4 because line 0
        # is the "graph():" header.
        cursor_node = list(graph.nodes)[3]
        actual_str = add_cursor_to_graph(graph, cursor_node)
        lines = actual_str.split("\n")

        self.assertTrue(
            lines[4].startswith("-->"),
            msg=f"cursor not on the expected line:\n{actual_str}",
        )
        # The marked line must be the selected node, not merely some line.
        self.assertIn(cursor_node.name, lines[4])