# xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/builtin_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
import torch
2
3
# https://pytorch.org/docs/stable/jit_builtin_functions.html#builtin-functions
5
6
class TSBuiltinOpsModule(torch.nn.Module):
    """Fixture module whose ``forward`` evaluates a fixed list of builtin calls.

    The expressions cover ``bool``/``int``/``float`` conversions, bitwise,
    arithmetic and comparison operators, ``max``/``min``, string methods and
    ``len``, applied to scalar tensors, Python numbers, strings, a list and
    two dicts.

    NOTE(review): the "TS" prefix and the TorchScript builtin-functions link
    in this file suggest the module is meant to be compiled with
    ``torch.jit.script`` rather than called eagerly -- confirm with the caller.
    """

    def forward(self):
        """Evaluate every listed expression and hand them all to one ``len`` call.

        Takes no inputs; every operand is created locally.

        NOTE(review): under plain (eager) Python this call cannot return
        normally: the builtin ``len`` accepts exactly one argument, and
        ``int(y)`` is 0 here, so ``int(x) / int(y)`` divides by zero. How the
        expression list behaves once compiled is not established by this
        file -- verify before relying on the return value.
        """
        x = torch.tensor(1)  # integer scalar tensor
        y = torch.tensor(0.5)  # floating-point scalar tensor
        b = float(1)
        s = "abcde"  # NOTE(review): not referenced anywhere below in this method
        l = ["1", "2", "test", "a{}b"]
        d = {"key": 1}
        d2 = {0: 100}
        return len(
            # type
            bool(x),
            bool(x.item()),
            int(y),
            int(y.item()),
            float(x),
            float(x.item()),
            # math
            # bitwise operators on tensors, bools and ints
            x & x,
            bool(x) & bool(x),
            int(x) & int(x),
            x | x,
            bool(x) | bool(x),
            int(x) | int(x),
            x << x,
            int(x) << int(x),
            x >> x,
            int(x) >> int(x),
            x ^ x,
            bool(x) ^ bool(x),
            int(x) ^ int(x),
            # arithmetic on Python scalars
            b * float(x),
            b * int(x),
            b + float(x),
            b - float(x),
            x.item() + y.item(),
            x.item() - y.item(),
            x.item() * y.item(),
            x.item() / y.item(),
            # comparisons, including mixed float/int operands
            float(x) < float(y),
            float(x) <= float(y),
            float(x) > float(y),
            float(x) > int(y),
            float(x) >= float(y),
            float(x) >= int(y),
            float(x) == float(y),
            float(x) == int(y),
            float(x) != float(y),
            int(x) != float(y),
            # true division
            float(x) / float(y),
            int(x) / int(y),
            # max / min over a tensor and over scalar pairs
            max(x),
            max(x.item(), y.item()),
            max(int(x), int(y)),
            max(float(x), float(y)),
            min(x),
            min(x.item(), y.item()),
            min(int(x), int(y)),
            min(float(x), float(y)),
            # numeric parsing of the string "1"
            int(l[0]),
            float(l[0]),
            # string
            str(torch.tensor(1)),
            l[2].find("t"),
            l[2].replace("t", "x"),
            l[2].lower(),
            l[2].startswith("t"),
            l[2].split("t"),
            l[2].strip(),
            l[2].rstrip(),
            l[2].lstrip(),
            l[2][slice(2)],
            l[3].format("x"),
            ord(l[2][0]),
            # len of a tensor, a list, a string and both dicts
            len(torch.randn(3)),
            len(l),
            len(l[2]),
            len(d),
            len(d2),
        )
87
88
class TSCollectionOpsModule(torch.nn.Module):
    """Fixture module whose ``forward`` exercises str, list and dict operations.

    It mutates a list (``reverse``, item assignment, ``extend``, ``pop``) and
    two dicts, one keyed by ``str`` and one by ``int`` (``clear``, ``update``,
    membership test, item assignment), then reads them back through indexing
    and the ``keys``/``items``/``values`` views.

    NOTE(review): the "TS" prefix and the TorchScript builtin-functions link
    in this file suggest the module is meant to be compiled with
    ``torch.jit.script`` rather than called eagerly -- confirm with the caller.
    """

    def forward(self):
        """Run the collection mutations, then pass the read-backs to one ``len`` call.

        Takes no inputs; every operand is created locally.

        NOTE(review): under plain (eager) Python this call cannot return
        normally, because the builtin ``len`` accepts exactly one argument and
        eight are supplied. How the expression list behaves once compiled is
        not established by this file -- verify before relying on the return
        value.
        """
        s = "abcde"
        # list
        l = ["1", "2", "test"]
        # Two consecutive reversals leave the element order unchanged.
        l.reverse()
        l.reverse()
        l[1] = "3"
        l.extend(["4"])
        # str dict
        d = {"key": 1}
        d.clear()
        d.update({"key": 0})
        if "key" in d:
            d["key"] = 2
        #  int dict
        d2 = {0: 100}
        if 0 in d2:
            d2.clear()
            d2[0] = 100

        return len(
            s[torch.tensor(1)],  # string indexed by a scalar tensor
            d["key"],
            d2[0],
            d.keys(),
            d.items(),
            d.values(),
            d2.values(),
            l.pop(),
        )
120