xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/tensor_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import torch
2
3
class TensorOpsModule(torch.nn.Module):
    """Module whose forward pass calls a wide range of general tensor ops.

    The ops cover tensor predicates, ``new_*`` constructors, attribute reads,
    in-place mutators and activation-style methods, all applied to a handful
    of small fixture tensors created inside the method.
    """

    def forward(self):
        """Return the result of ``tensor_general_ops``; takes no inputs."""
        return self.tensor_general_ops()

    def tensor_general_ops(self):
        """Evaluate each listed op once and pass all results to one ``len`` call.

        NOTE(review): plain Python ``len`` accepts exactly one argument, so this
        presumably relies on the module being compiled (e.g. scripted) rather
        than run eagerly -- confirm against the code that loads this module.
        """
        # Fixtures: a is a 4-element float vector, b a one-element tensor (fed to
        # is_nonzero), x a 2-element vector of ones, c a complex (cfloat)
        # vector, and w / v are 4x4x4x4 random tensors (v is the copy_ source).
        a = torch.randn(4)
        b = torch.tensor([1.5])
        x = torch.ones((2,))
        c = torch.randn(4, dtype=torch.cfloat)
        w = torch.rand(4, 4, 4, 4)
        v = torch.rand(4, 4, 4, 4)
        # Arguments are evaluated in the order written. Methods with a trailing
        # underscore modify their tensor in place, so later arguments see the
        # mutated a / w / x. Commented-out entries are disabled ops; the reason
        # they are disabled is not recorded in this file.
        return len(
            # torch.is_tensor(a),
            # torch.is_storage(a),
            torch.is_complex(a),
            torch.is_conj(a),
            torch.is_floating_point(a),
            torch.is_nonzero(b),
            # torch.set_default_dtype(torch.float32),
            # torch.get_default_dtype(),
            # torch.set_default_tensor_type(torch.DoubleTensor),
            torch.numel(a),
            # torch.set_printoptions(),
            # torch.set_flush_denormal(False),
            # https://pytorch.org/docs/stable/tensors.html#tensor-class-reference
            # x.new_tensor([[0, 1], [2, 3]]),
            x.new_full((3, 4), 3.141592),
            x.new_empty((2, 3)),
            x.new_ones((2, 3)),
            x.new_zeros((2, 3)),
            # Attribute reads (no call parentheses).
            x.is_cuda,
            x.is_quantized,
            x.is_meta,
            x.device,
            x.dim(),
            # Real and imaginary parts are taken from the complex fixture c.
            c.real,
            c.imag,
            # x.backward(),
            x.clone(),
            w.contiguous(),
            w.contiguous(memory_format=torch.channels_last),
            # In-place writes into w: tensor source first, then int and float
            # scalar sources.
            w.copy_(v),
            w.copy_(1),
            w.copy_(0.5),
            x.cpu(),
            # x.cuda(),
            # x.data_ptr(),
            x.dense_dim(),
            w.fill_diagonal_(0),
            w.element_size(),
            w.exponential_(),
            w.fill_(0),
            w.geometric_(0.5),
            a.index_fill(0, torch.tensor([0, 2]), 1),
            # In-place and out-of-place index_put at the position of a's maximum.
            a.index_put_([torch.argmax(a)], torch.tensor(1.0)),
            a.index_put([torch.argmax(a)], torch.tensor(1.0)),
            w.is_contiguous(),
            c.is_complex(),
            w.is_conj(),
            w.is_floating_point(),
            w.is_leaf,
            w.is_pinned(),
            w.is_set_to(w),
            # w.is_shared,
            # NOTE(review): w is created with torch.rand above (a dense tensor);
            # is_coalesced / coalesce look like sparse-tensor methods -- confirm
            # these are only meant to be compiled, not executed.
            w.is_coalesced(),
            w.coalesce(),
            w.is_signed(),
            w.is_sparse,
            torch.tensor([1]).item(),
            x.log_normal_(),
            # x.masked_scatter_(),
            # x.masked_scatter(),
            # w.normal(),
            w.numel(),
            # w.pin_memory(),
            # w.put_(0, torch.tensor([0, 1], w)),
            x.repeat(4, 2),
            # Clamp / activation methods on a, mostly as in-place + out-of-place
            # pairs.
            a.clamp_(0),
            a.clamp(0),
            a.clamp_min(0),
            a.hardsigmoid_(),
            a.hardsigmoid(),
            a.hardswish_(),
            a.hardswish(),
            a.hardtanh_(),
            a.hardtanh(),
            a.leaky_relu_(),
            a.leaky_relu(),
            a.relu_(),
            a.relu(),
            # Self-referential calls: a is resized / typed against itself.
            a.resize_as_(a),
            a.type_as(a),
            a._shape_as_tensor(),
            a.requires_grad_(False),
        )
99
100
class TensorCreationOpsModule(torch.nn.Module):
    """Module whose forward pass calls tensor creation and construction ops.

    The ops include literal construction, zeros/ones/full/empty families,
    ranges, quantize/dequantize, and complex/polar/heaviside builders.
    """

    def forward(self):
        """Return the result of ``tensor_creation_ops``; takes no inputs."""
        return self.tensor_creation_ops()

    def tensor_creation_ops(self):
        """Evaluate each listed creation op once and pass all results to ``len``.

        NOTE(review): plain Python ``len`` accepts exactly one argument, so this
        presumably relies on the module being compiled (e.g. scripted) rather
        than run eagerly -- confirm against the code that loads this module.
        """
        # i is a 2x3 integer tensor; below it is reused as the out= destination
        # of torch.zeros([2, 3], out=i).
        i = torch.tensor([[0, 1, 1], [2, 0, 2]])
        # v is only referenced by the commented-out sparse_coo_tensor call.
        v = torch.tensor([3, 4, 5], dtype=torch.float32)
        # real / imag are the two float inputs to torch.complex and torch.polar.
        real = torch.tensor([1, 2], dtype=torch.float32)
        imag = torch.tensor([3, 4], dtype=torch.float32)
        # inp / values are the two inputs to torch.heaviside.
        inp = torch.tensor([-1.5, 0.0, 2.0])
        values = torch.tensor([0.5])
        # Per-channel quantized (quint8) 2x2 tensor along axis 0, with one scale
        # and one zero point per row; consumed by torch.dequantize below.
        quantized = torch.quantize_per_channel(
            torch.tensor([[-1.0, 0.0], [1.0, 2.0]]),
            torch.tensor([0.1, 0.01]),
            torch.tensor([10, 0]),
            0,
            torch.quint8,
        )
        return len(
            torch.tensor([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]]),
            # torch.sparse_coo_tensor(i, v, [2, 3]), # not work for iOS
            torch.as_tensor([1, 2, 3]),
            torch.as_strided(torch.randn(3, 3), (2, 2), (1, 2)),
            # zeros / ones are each called with varargs, tuple, list and scalar
            # size forms; the list form of zeros also writes into i via out=.
            torch.zeros(2, 3),
            torch.zeros((2, 3)),
            torch.zeros([2, 3], out=i),
            torch.zeros(5),
            torch.zeros_like(torch.empty(2, 3)),
            torch.ones(2, 3),
            torch.ones((2, 3)),
            torch.ones([2, 3]),
            torch.ones(5),
            torch.ones_like(torch.empty(2, 3)),
            # Range builders with end-only, start/end and start/end/step forms.
            torch.arange(5),
            torch.arange(1, 4),
            torch.arange(1, 2.5, 0.5),
            torch.range(1, 4),
            torch.range(1, 4, 0.5),
            torch.linspace(3.0, 3.0, steps=1),
            torch.logspace(start=2, end=2, steps=1, base=2.0),
            torch.eye(3),
            torch.empty(2, 3),
            torch.empty_like(torch.empty(2, 3), dtype=torch.int64),
            torch.empty_strided((2, 3), (1, 2)),
            torch.full((2, 3), 3.141592),
            torch.full_like(torch.full((2, 3), 3.141592), 2.71828),
            # Per-tensor quantization with scale 0.1 and zero point 10.
            torch.quantize_per_tensor(
                torch.tensor([-1.0, 0.0, 1.0, 2.0]), 0.1, 10, torch.quint8
            ),
            torch.dequantize(quantized),
            torch.complex(real, imag),
            torch.polar(real, imag),
            torch.heaviside(inp, values),
        )
155
156
class TensorIndexingOpsModule(torch.nn.Module):
    """Module whose forward pass calls indexing, slicing and joining ops.

    The ops include cat/stack/split variants, gather/scatter variants,
    dimension moves and transposes, masked and indexed selection, and
    ``torch.where``.
    """

    def forward(self):
        """Return the result of ``tensor_indexing_ops``; takes no inputs."""
        return self.tensor_indexing_ops()

    def tensor_indexing_ops(self):
        """Evaluate each listed indexing op once and pass all results to ``len``.

        NOTE(review): plain Python ``len`` accepts exactly one argument, so this
        presumably relies on the module being compiled (e.g. scripted) rather
        than run eagerly -- confirm against the code that loads this module.
        """
        # Fixtures: x is 2x4 and y is 4x4 (both random floats); t is a 2x2
        # integer tensor used as the index argument; mask is the boolean result
        # of x >= 0.5; i is the index list passed to dsplit / hsplit / vsplit.
        x = torch.randn(2, 4)
        y = torch.randn(4, 4)
        t = torch.tensor([[0, 0], [1, 0]])
        mask = x.ge(0.5)
        i = [0, 1]
        # Arguments are evaluated in the order written; x.scatter_ and
        # x.scatter_add_ below modify x in place, so every argument after them
        # sees the updated x.
        return len(
            torch.cat((x, x, x), 0),
            torch.concat((x, x, x), 0),
            torch.conj(x),
            torch.chunk(x, 2),
            # dsplit gets its own 3-D input; x itself is only 2-D.
            torch.dsplit(torch.randn(2, 2, 4), i),
            torch.column_stack((x, x)),
            torch.dstack((x, x)),
            torch.gather(x, 0, t),
            torch.hsplit(x, i),
            torch.hstack((x, x)),
            torch.index_select(x, 0, torch.tensor([0, 1])),
            x.index(t),
            torch.masked_select(x, mask),
            torch.movedim(x, 1, 0),
            torch.moveaxis(x, 1, 0),
            torch.narrow(x, 0, 0, 2),
            torch.nonzero(x),
            torch.permute(x, (0, 1)),
            torch.reshape(x, (-1,)),
            torch.row_stack((x, x)),
            torch.select(x, 0, 0),
            # Out-of-place scatter variants; y (4x4) is the base tensor for the
            # diagonal / select scatters, which receive a 4-element ones vector.
            torch.scatter(x, 0, t, x),
            x.scatter(0, t, x.clone()),
            torch.diagonal_scatter(y, torch.ones(4)),
            torch.select_scatter(y, torch.ones(4), 0, 0),
            torch.slice_scatter(x, x),
            torch.scatter_add(x, 0, t, x),
            # In-place scatters: x is mutated from here on.
            x.scatter_(0, t, y),
            x.scatter_add_(0, t, y),
            # torch.scatter_reduce(x, 0, t, reduce="sum"),
            torch.split(x, 1),
            torch.squeeze(x, 0),
            torch.stack([x, x]),
            torch.swapaxes(x, 0, 1),
            torch.swapdims(x, 0, 1),
            torch.t(x),
            torch.take(x, t),
            torch.take_along_dim(x, torch.argmax(x)),
            # tensor_split with a section count and with explicit indices.
            torch.tensor_split(x, 1),
            torch.tensor_split(x, [0, 1]),
            torch.tile(x, (2, 2)),
            torch.transpose(x, 0, 1),
            torch.unbind(x),
            torch.unsqueeze(x, -1),
            torch.vsplit(x, i),
            torch.vstack((x, x)),
            # where: condition-only form, then tensor/scalar and tensor/tensor
            # forms.
            torch.where(x),
            torch.where(t > 0, t, 0),
            torch.where(t > 0, t, t),
        )
218
219
class TensorTypingOpsModule(torch.nn.Module):
    """Module whose forward pass calls ``Tensor.to`` with many targets.

    The targets are a range of dtypes, a device, a device/dtype keyword
    pair, and a memory format.
    """

    def forward(self):
        """Return the result of ``tensor_typing_ops``; takes no inputs."""
        return self.tensor_typing_ops()

    def tensor_typing_ops(self):
        """Evaluate each listed ``x.to(...)`` call and pass all results to ``len``.

        NOTE(review): plain Python ``len`` accepts exactly one argument, so this
        presumably relies on the module being compiled (e.g. scripted) rather
        than run eagerly -- confirm against the code that loads this module.
        """
        # 4-D random float input; the 4-D shape is what the channels_last
        # conversion at the end is applied to.
        x = torch.randn(1, 3, 4, 4)
        return len(
            # Floating-point and complex dtypes.
            x.to(torch.float),
            x.to(torch.double),
            x.to(torch.cfloat),
            x.to(torch.cdouble),
            x.to(torch.half),
            x.to(torch.bfloat16),
            # Integer and boolean dtypes.
            x.to(torch.uint8),
            x.to(torch.int8),
            x.to(torch.short),
            x.to(torch.int),
            x.to(torch.long),
            x.to(torch.bool),
            # Device, device + dtype keywords, and memory-format overloads.
            x.to(torch.device("cpu")),
            x.to(device="cpu", dtype=torch.float),
            x.to(memory_format=torch.channels_last),
        )
243
244
class TensorViewOpsModule(torch.nn.Module):
    """Module whose forward pass calls view-style tensor methods.

    The methods include slicing, detach, diagonal, expand, select,
    unflatten, unfold and view.
    """

    def forward(self):
        """Return the result of ``tensor_view_ops``; takes no inputs."""
        return self.tensor_view_ops()

    def tensor_view_ops(self):
        """Evaluate each listed view op once and pass all results to ``len``.

        NOTE(review): plain Python ``len`` accepts exactly one argument, so this
        presumably relies on the module being compiled (e.g. scripted) rather
        than run eagerly -- confirm against the code that loads this module.
        """
        # x is 4x4x1 (16 elements, which is what view(16) / view_as below rely
        # on); y is 4x4x2 and only serves as the target shape for expand_as.
        x = torch.randn(4, 4, 1)
        y = torch.randn(4, 4, 2)
        return len(
            x[0, 2:],
            # Out-of-place detach followed by the in-place variant.
            x.detach(),
            x.detach_(),
            x.diagonal(),
            # -1 keeps the existing size; only the trailing size-1 dim is
            # expanded.
            x.expand(-1, -1, 3),
            x.expand_as(y),
            x.select(0, 1),
            # Splits dim 1 (size 4) into (2, 2).
            x.unflatten(1, (2, 2)),
            x.unfold(1, 2, 2),
            x.view(16),
            x.view_as(torch.randn(16)),
        )
265