# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond


class MySubModule(torch.nn.Module):
    """Helper submodule whose bound ``forward`` serves as cond()'s true branch.

    ``forward`` delegates to the plain method ``foo``, which returns ``x.cos()``.
    """

    def foo(self, x):
        return x.cos()

    def forward(self, x):
        return self.foo(x)


class CondBranchClassMethod(torch.nn.Module):
    """
    The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
      - both branches must take the same args, which must also match the branch args passed to cond.
      - both branches must return a single tensor
      - returned tensor must have the same tensor metadata, e.g. shape and dtype
      - branch function can be free function, nested function, lambda, class methods
      - branch function cannot have closure variables
      - no inplace mutations on inputs or global variables

    This example demonstrates using class methods in cond().

    NOTE: If the `pred` is tested on a dim with batch size < 2, it will be specialized.
    """

    def __init__(self) -> None:
        super().__init__()
        # The submodule's bound ``forward`` method is passed to cond() as true_fn.
        self.subm = MySubModule()

    def bar(self, x):
        """False branch: a plain class method returning ``x.sin()``."""
        return x.sin()

    def forward(self, x):
        """Return ``x.cos()`` when ``x.shape[0] <= 2``, otherwise ``x.sin()``.

        Both branches are class methods: a submodule's ``forward`` and ``self.bar``.
        The operands are passed as a list, ``[x]``, and each branch receives them
        unpacked.
        """
        # The predicate depends on a (possibly dynamic) dimension of the input.
        return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])


# Sample inputs for this example. x.shape[0] == 3, so in eager mode the false branch runs.
example_args = (torch.randn(3),)
# Feature tags describing what this example exercises.
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondBranchClassMethod()