1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import torch 8from torch.nn import Module # @manual 9 10 11class FTCondBasic(Module): 12 def __init__(self): 13 super().__init__() 14 15 def forward(self, inp): 16 def true_branch(x): 17 return x + x 18 19 def false_branch(x): 20 return x * x 21 22 return torch.ops.higher_order.cond( 23 inp.sum() > 4, true_branch, false_branch, [inp] 24 ) 25 26 def get_random_inputs(self): 27 return (torch.rand(5),) 28 29 30class FTCondDynShape(Module): 31 def __init__(self): 32 super().__init__() 33 34 def forward(self, inp): 35 def true_branch(x): 36 return x + x + x 37 38 def false_branch(x): 39 return x * x * x 40 41 return torch.ops.higher_order.cond( 42 inp.sum() > 4, true_branch, false_branch, [inp] 43 ) 44 45 def get_upper_bound_inputs(self): 46 return (torch.rand(8),) 47 48 def get_random_inputs(self): 49 return (torch.rand(5),) 50 51 52class FTCondDeadCode(Module): 53 """ 54 A toy model used to test DCE on sub modules. 55 56 The graph generated for torch.inverse will contain a node: 57 torch.ops.aten._linalg_check_errors.default 58 to check for errors. There are no out variants for this op and executorch 59 runtime does not support it. For now, we simply erase this node by DCE 60 since the Fx code does not consider this node as having side effect. 61 """ 62 63 def __init__(self): 64 super().__init__() 65 66 def forward(self, inp): 67 def true_branch(x): 68 x - 1 69 return x + 1 70 71 def false_branch(x): 72 return x * 2 73 74 return torch.ops.higher_order.cond( 75 inp.sum() > 4, true_branch, false_branch, [inp] 76 ) 77 78 def get_random_inputs(self): 79 return (torch.eye(5) * 2,) 80 81 82class FTMapBasic(Module): 83 def __init__(self): 84 super().__init__() 85 86 def forward(self, xs, y): 87 def f(x, y): 88 return x + y 89 90 return torch.ops.higher_order.map(f, xs, y) + xs 91 92 def get_random_inputs(self): 93 return torch.rand(2, 4), torch.rand(4) 94 95 96class FTMapDynShape(Module): 97 def __init__(self): 98 super().__init__() 99 100 def forward(self, xs, y): 101 def f(x, y): 102 return x + y 103 104 return torch.ops.higher_order.map(f, xs, y) + xs 105 106 def get_upper_bound_inputs(self): 107 return torch.rand(4, 4), torch.rand(4) 108 109 def get_random_inputs(self): 110 return torch.rand(2, 4), torch.rand(4) 111