xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/cond_branch_class_method.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
4from functorch.experimental.control_flow import cond
5
class MySubModule(torch.nn.Module):
    """Minimal submodule with no parameters; `forward` returns the cosine of its input."""

    def forward(self, x):
        # Delegates to the helper method `foo` rather than computing inline.
        return self.foo(x)

    def foo(self, x):
        # Elementwise cosine of the input tensor.
        return x.cos()
12
class CondBranchClassMethod(torch.nn.Module):
    """
    Rules for the branch functions (`true_fn` and `false_fn`) handed to cond():
      - both branches take the same arguments, and those arguments must match
        the branch operands passed to cond.
      - both branches return a single tensor.
      - the returned tensors must share tensor metadata, e.g. shape and dtype.
      - a branch may be a free function, a nested function, a lambda, or a
        class method.
      - a branch may not have closure variables.
      - no in-place mutations of inputs or global variables.


    This example demonstrates using class methods as cond() branches.

    NOTE: If `pred` tests a dim whose size is < 2, that dim will be specialized.
    """

    def __init__(self) -> None:
        super().__init__()
        # The submodule's bound `forward` method serves as the true branch.
        self.subm = MySubModule()

    def bar(self, x):
        # False branch: elementwise sine of the operand.
        return x.sin()

    def forward(self, x):
        # Predicate depends on the (possibly dynamic) leading dimension.
        small_leading_dim = x.shape[0] <= 2
        operands = [x]
        return cond(small_leading_dim, self.subm.forward, self.bar, operands)
38
# Leading dim of size 3: per the class NOTE, dims of size < 2 get specialized,
# so this size keeps the cond predicate's dimension symbolic.
example_args = (torch.randn(3),)
tags = {"torch.cond", "torch.dynamic-shape"}
model = CondBranchClassMethod()
45