# mypy: allow-untyped-defs
import torch

from functorch.experimental.control_flow import cond


class CondPredicate(torch.nn.Module):
    """
    The conditional statement (aka predicate) passed to cond() must be one of the following:
      - torch.Tensor with a single element
      - boolean expression

    NOTE: If `pred` is tested on a dim whose size is < 2, that dim will be specialized.
    """

    def forward(self, x):
        # `and` short-circuits: x.shape[2] is only read when x has a third
        # dimension, so lower-rank inputs yield False instead of raising.
        pred = x.dim() > 2 and x.shape[2] > 10

        # True branch -> cos, false branch -> sin; both receive the operands in [x].
        return cond(pred, lambda t: t.cos(), lambda t: t.sin(), [x])


example_args = (torch.randn(6, 4, 3),)
tags = {
    "torch.cond",
    "torch.dynamic-shape",
}
model = CondPredicate()