xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/math_ops.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# https://pytorch.org/docs/stable/torch.html#math-operations
2
3import math
4
5import torch
6
7
class PointwiseOpsModule(torch.nn.Module):
    """Catalogue of pointwise math ops from the ``torch`` math-operations docs.

    ``forward`` evaluates every op listed in ``pointwise_ops``. The results are
    only gathered as arguments to ``len`` so that each op appears in the
    module; the returned value is not meaningful on its own.

    NOTE(review): several calls only type-check under TorchScript. Examples
    are ``torch.ceil``/``torch.floor``/``torch.exp`` on a Python float,
    ``torch.sum(5)``, and ``len`` with many arguments, which eager Python
    rejects. The module is presumably meant to be scripted to produce a
    mobile test model; confirm against the model generator.
    """

    def forward(self):
        return self.pointwise_ops()

    def pointwise_ops(self):
        # Shared random float inputs. Several calls below mutate ``a``/``b``
        # in place (``uniform_``, ``add_``, ``div_``, ``mul_``, ``sub_``,
        # ``trunc_`` and ``out=`` arguments), so the call order matters.
        a = torch.randn(4)
        b = torch.randn(4)
        # int8 operand for the bitwise ops.
        t = torch.tensor([-1, -2, 3], dtype=torch.int8)
        # int8 operands for the logical_* ops.
        r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
        s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
        # Operands for xlogy.
        f = torch.zeros(3)
        g = torch.tensor([-1, 0, 1])
        # Input for the nan_to_num family.
        w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
        return len(
            torch.abs(torch.tensor([-1, -2, 3])),
            torch.absolute(torch.tensor([-1, -2, 3])),
            torch.acos(a),
            torch.arccos(a),
            torch.acosh(a.uniform_(1.0, 2.0)),
            torch.add(a, 20),
            torch.add(a, b, out=a),
            b.add(a),
            b.add(a, out=b),
            b.add_(a),
            b.add(1),
            torch.add(a, torch.randn(4, 1), alpha=10),
            torch.addcdiv(
                torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
            ),
            torch.addcmul(
                torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
            ),
            torch.angle(a),
            torch.asin(a),
            torch.arcsin(a),
            torch.asinh(a),
            torch.arcsinh(a),
            torch.atan(a),
            torch.arctan(a),
            torch.atanh(a.uniform_(-1.0, 1.0)),
            torch.arctanh(a.uniform_(-1.0, 1.0)),
            torch.atan2(a, a),
            torch.bitwise_not(t),
            torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.ceil(a),
            torch.ceil(float(torch.tensor(0.5))),
            torch.ceil(torch.tensor(0.5).item()),
            torch.clamp(a, min=-0.5, max=0.5),
            torch.clamp(a, min=0.5),
            torch.clamp(a, max=0.5),
            torch.clip(a, min=-0.5, max=0.5),
            torch.conj(a),
            torch.copysign(a, 1),
            torch.copysign(a, b),
            torch.cos(a),
            torch.cosh(a),
            torch.deg2rad(
                torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])
            ),
            torch.div(a, b),
            a.div(b),
            a.div(1),
            a.div_(b),
            torch.divide(a, b, rounding_mode="trunc"),
            torch.divide(a, b, rounding_mode="floor"),
            torch.digamma(torch.tensor([1.0, 0.5])),
            torch.erf(torch.tensor([0.0, -1.0, 10.0])),
            torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
            torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
            torch.exp(torch.tensor([0.0, math.log(2.0)])),
            torch.exp(float(torch.tensor(1))),
            torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
            torch.expm1(torch.tensor([0.0, math.log(2.0)])),
            torch.fake_quantize_per_channel_affine(
                torch.randn(2, 2, 2),
                (torch.randn(2) + 1) * 0.05,
                torch.zeros(2),
                1,
                0,
                255,
            ),
            torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
            torch.float_power(torch.randint(10, (4,)), 2),
            torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
            torch.floor(a),
            torch.floor(float(torch.tensor(1))),
            torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
            torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
            torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
            torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
            torch.frac(torch.tensor([1.0, 2.5, -3.2])),
            torch.randn(4, dtype=torch.cfloat).imag,
            torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
            torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
            torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
            torch.lerp(
                torch.arange(1.0, 5.0),
                torch.empty(4).fill_(10),
                torch.full_like(torch.arange(1.0, 5.0), 0.5),
            ),
            torch.lgamma(torch.arange(0.5, 2, 0.5)),
            torch.log(torch.arange(5) + 10),
            torch.log10(torch.rand(5)),
            torch.log1p(torch.randn(5)),
            torch.log2(torch.rand(5)),
            torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
            torch.logaddexp(
                torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp(
                torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
            torch.logaddexp2(
                torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp2(
                torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logical_and(r, s),
            torch.logical_and(r.double(), s.double()),
            torch.logical_and(r.double(), s),
            torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
            torch.logical_not(torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
            torch.logical_not(
                torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
                out=torch.empty(3, dtype=torch.int16),
            ),
            torch.logical_or(r, s),
            torch.logical_or(r.double(), s.double()),
            torch.logical_or(r.double(), s),
            torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logical_xor(r, s),
            torch.logical_xor(r.double(), s.double()),
            torch.logical_xor(r.double(), s),
            torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logit(torch.rand(5), eps=1e-6),
            torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
            torch.i0(torch.arange(5, dtype=torch.float32)),
            torch.igamma(a, b),
            torch.igammac(a, b),
            torch.mul(torch.randn(3), 100),
            b.mul(a),
            b.mul(5),
            b.mul(a, out=b),
            b.mul_(a),
            b.mul_(5),
            torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
            torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
            # NOTE(review): this literal is not bound to a name. It may have
            # been meant as the nan_to_num input, since ``w`` holds no nan/inf
            # values; confirm.
            torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
            torch.nan_to_num(w),
            torch.nan_to_num_(w),
            torch.nan_to_num(w, nan=2.0),
            torch.nan_to_num(w, nan=2.0, posinf=1.0),
            torch.neg(torch.randn(5)),
            # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
            torch.polygamma(1, torch.tensor([1.0, 0.5])),
            torch.polygamma(2, torch.tensor([1.0, 0.5])),
            torch.polygamma(3, torch.tensor([1.0, 0.5])),
            torch.polygamma(4, torch.tensor([1.0, 0.5])),
            torch.pow(a, 2),
            torch.pow(2, float(torch.tensor(0.5))),
            torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
            torch.rad2deg(
                torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])
            ),
            torch.randn(4, dtype=torch.cfloat).real,
            torch.reciprocal(a),
            torch.remainder(torch.tensor([-3.0, -2.0]), 2),
            torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
            torch.round(a),
            torch.round(torch.tensor(0.5).item()),
            torch.rsqrt(a),
            torch.sigmoid(a),
            torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
            torch.sgn(a),
            torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
            torch.sin(a),
            torch.sinc(a),
            torch.sinh(a),
            torch.sqrt(a),
            torch.square(a),
            torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
            b.sub(a),
            b.sub_(a),
            b.sub(5),
            torch.sum(5),
            torch.tan(a),
            torch.tanh(a),
            torch.true_divide(a, a),
            torch.trunc(a),
            torch.trunc_(a),
            torch.xlogy(f, g),
            torch.xlogy(f, 4),
            torch.xlogy(2, g),
        )
209
210
class ReductionOpsModule(torch.nn.Module):
    """Catalogue of reduction ops (argmax/argmin, mean, norm, quantile, ...).

    Each op's result is passed to ``len`` only so the op is recorded in the
    module; the value ``forward`` returns is not meaningful in itself.

    NOTE(review): ``len`` with many arguments is rejected by eager Python,
    so this presumably only runs after TorchScript compilation; confirm.
    """

    def forward(self):
        return self.reduction_ops()

    def reduction_ops(self):
        x = torch.randn(4)
        y = torch.randn(4)
        # 0-dim tensor used as the input of the tensor-``p`` norm overload.
        half = torch.tensor(0.5)
        return len(
            torch.argmax(x),
            torch.argmin(x),
            torch.amax(x),
            torch.amin(x),
            torch.aminmax(x),
            torch.all(x),
            torch.any(x),
            torch.max(x),
            x.max(x),
            torch.max(x, 0),
            torch.min(x),
            x.min(x),
            torch.min(x, 0),
            torch.dist(x, y),
            torch.logsumexp(x, 0),
            torch.mean(x),
            torch.mean(x, 0),
            torch.nanmean(x),
            torch.median(x),
            torch.nanmedian(x),
            torch.mode(x),
            torch.norm(x),
            x.norm(2),
            torch.norm(x, dim=0),
            torch.norm(half, torch.tensor(2)),
            torch.nansum(x),
            torch.prod(x),
            torch.quantile(x, torch.tensor([0.25, 0.5, 0.75])),
            torch.quantile(x, 0.5),
            torch.nanquantile(x, torch.tensor([0.25, 0.5, 0.75])),
            torch.std(x),
            torch.std_mean(x),
            torch.sum(x),
            torch.unique(x),
            torch.unique_consecutive(x),
            torch.var(x),
            torch.var_mean(x),
            torch.count_nonzero(x),
        )
259
260
class ComparisonOpsModule(torch.nn.Module):
    """Catalogue of comparison, sorting and predicate ops on 0-dim tensors.

    Each op's result is passed to ``len`` only so the op is recorded in the
    module; the value ``forward`` returns is not meaningful in itself.

    NOTE(review): ``len`` with many arguments is rejected by eager Python,
    so this presumably only runs after TorchScript compilation; confirm.
    """

    def forward(self):
        # Two 0-dim integer tensors; several ops are also exercised with the
        # Python scalar 1 as the second operand.
        zero = torch.tensor(0)
        one = torch.tensor(1)
        return len(
            torch.allclose(zero, one),
            torch.argsort(zero),
            torch.eq(zero, one),
            torch.eq(zero, 1),
            torch.equal(zero, one),
            torch.ge(zero, one),
            torch.ge(zero, 1),
            torch.greater_equal(zero, one),
            torch.greater_equal(zero, 1),
            torch.gt(zero, one),
            torch.gt(zero, 1),
            torch.greater(zero, one),
            torch.isclose(zero, one),
            torch.isfinite(zero),
            torch.isin(zero, one),
            torch.isinf(zero),
            torch.isposinf(zero),
            torch.isneginf(zero),
            torch.isnan(zero),
            torch.isreal(zero),
            torch.kthvalue(zero, 1),
            torch.le(zero, one),
            torch.le(zero, 1),
            torch.less_equal(zero, one),
            torch.lt(zero, one),
            torch.lt(zero, 1),
            torch.less(zero, one),
            torch.maximum(zero, one),
            torch.minimum(zero, one),
            torch.fmax(zero, one),
            torch.fmin(zero, one),
            torch.ne(zero, one),
            torch.ne(zero, 1),
            torch.not_equal(zero, one),
            torch.sort(zero),
            torch.topk(zero, 1),
            torch.msort(zero),
        )
304
305
class OtherMathOpsModule(torch.nn.Module):
    """Catalogue of miscellaneous tensor ops.

    Covers shape helpers, cumulative ops, diag/tri helpers, flips, histograms,
    view_as_real/complex and similar. Each op's result is passed to ``len``
    only so the op is recorded in the module; the returned value is not
    meaningful in itself.

    NOTE(review): ``len`` with many arguments is rejected by eager Python,
    so this presumably only runs after TorchScript compilation; confirm.
    """

    def forward(self):
        return self.other_ops()

    def other_ops(self):
        a = torch.randn(4)
        b = torch.randn(4)
        # Non-negative int64 values, as bincount requires; also reused for
        # gcd / lcm / repeat_interleave.
        c = torch.randint(0, 8, (5,), dtype=torch.int64)
        # Has a dimension of size 3, which torch.cross needs.
        e = torch.randn(4, 3)
        # All three dims equal so the "iii" einsum (diagonal sum) is valid.
        f = torch.randn(4, 4, 4)
        dims = [0, 1]
        return len(
            torch.atleast_1d(a),
            torch.atleast_2d(a),
            torch.atleast_3d(a),
            torch.bincount(c),
            torch.block_diag(a),
            torch.broadcast_tensors(a),
            # ``(4,)`` is a one-element tuple; a bare ``(4)`` is just the int 4.
            torch.broadcast_to(a, (4,)),
            # torch.broadcast_shapes(a),
            torch.bucketize(a, b),
            torch.cartesian_prod(a),
            torch.cdist(e, e),
            torch.clone(a),
            torch.combinations(a),
            torch.corrcoef(a),
            # torch.cov(a),
            torch.cross(e, e),
            torch.cummax(a, 0),
            torch.cummin(a, 0),
            torch.cumprod(a, 0),
            torch.cumsum(a, 0),
            torch.diag(a),
            torch.diag_embed(a),
            torch.diagflat(a),
            torch.diagonal(e),
            torch.diff(a),
            torch.einsum("iii", f),
            torch.flatten(a),
            torch.flip(e, dims),
            torch.fliplr(e),
            torch.flipud(e),
            torch.kron(a, b),
            torch.rot90(e),
            torch.gcd(c, c),
            torch.histc(a),
            torch.histogram(a),
            torch.meshgrid(a),
            torch.meshgrid(a, indexing="xy"),
            torch.lcm(c, c),
            torch.logcumsumexp(a, 0),
            torch.ravel(a),
            torch.renorm(e, 1, 0, 5),
            torch.repeat_interleave(c),
            torch.roll(a, 1, 0),
            torch.searchsorted(a, b),
            torch.tensordot(e, e),
            torch.trace(e),
            torch.tril(e),
            torch.tril_indices(3, 3),
            torch.triu(e),
            torch.triu_indices(3, 3),
            torch.vander(a),
            torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
            torch.view_as_complex(torch.randn(4, 2)).real,
            torch.resolve_conj(a),
            torch.resolve_neg(a),
        )
375
376
class SpectralOpsModule(torch.nn.Module):
    """Catalogue of spectral ops: STFT/ISTFT and the window-function factories.

    Each op's result is passed to ``len`` only so the op is recorded in the
    module; the value ``forward`` returns is not meaningful in itself.

    NOTE(review): ``len`` with many arguments is rejected by eager Python,
    so this presumably only runs after TorchScript compilation; confirm.
    """

    def forward(self):
        return self.spectral_ops()

    def spectral_ops(self):
        signal = torch.randn(10)
        # Presumably a real-valued view of a complex spectrogram (trailing
        # dim of 2) as the legacy istft input format; confirm.
        spectrum = torch.randn(10, 8, 4, 2)
        return len(
            torch.stft(signal, 8),
            torch.stft(signal, torch.tensor(8)),
            torch.istft(spectrum, 8),
            torch.bartlett_window(2, dtype=torch.float),
            torch.blackman_window(2, dtype=torch.float),
            torch.hamming_window(4, dtype=torch.float),
            torch.hann_window(4, dtype=torch.float),
            torch.kaiser_window(4, dtype=torch.float),
        )
394
395
class BlasLapackOpsModule(torch.nn.Module):
    """Catalogue of BLAS/LAPACK-style linear-algebra ops.

    Each op's result is passed to ``len`` only so the op is recorded in the
    module; the value ``forward`` returns is not meaningful in itself. Ops
    that are deprecated, unsupported or need LAPACK are kept as commented-out
    entries with the reason noted where one was given.

    NOTE(review): ``len`` with many arguments is rejected by eager Python,
    so this presumably only runs after TorchScript compilation; confirm.
    """

    def forward(self):
        return self.blas_lapack_ops()

    def blas_lapack_ops(self):
        square = torch.randn(3, 3)
        # Batched operands: (10, 3, 4) @ (10, 4, 3) -> (10, 3, 3).
        batch1 = torch.randn(10, 3, 4)
        batch2 = torch.randn(10, 4, 3)
        vec = torch.randn(3)
        return len(
            torch.addbmm(square, batch1, batch2),
            torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
            torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
            torch.addr(torch.zeros(3, 3), vec, vec),
            torch.baddbmm(square, batch1, batch2),
            torch.bmm(batch1, batch2),
            torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
            # torch.cholesky(batch1), # deprecated
            # torch.cholesky_inverse(torch.randn(3, 3)), # had some error
            # torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
            torch.dot(vec, vec),
            # torch.linalg.eig(square), # not build with lapack
            # torch.geqrf(batch1),
            torch.ger(vec, vec),
            torch.inner(square, square),
            # torch.inverse(square),
            # torch.det(square),
            # torch.logdet(square),
            # torch.slogdet(square),
            # torch.lstsq(square, square),
            # torch.linalg.lu_factor(square),
            # torch.lu_solve(square, *torch.linalg.lu_factor(square)),
            # torch.lu_unpack(*torch.linalg.lu_factor(square)),
            torch.matmul(square, square),
            torch.matrix_power(square, 2),
            # torch.matrix_rank(square),
            torch.matrix_exp(square),
            torch.mm(square, square),
            torch.mv(square, vec),
            # torch.orgqr(batch1, square),
            # torch.ormqr(batch1, square, vec),
            torch.outer(vec, vec),
            # torch.pinverse(square),
            # torch.qr(batch1),
            # torch.solve(square, square),
            # torch.svd(batch1),
            # torch.svd_lowrank(batch1),
            # torch.pca_lowrank(batch1),
            # torch.symeig(batch1), # deprecated
            # torch.lobpcg(batch1, batch2), # not supported
            torch.trapz(square, square),
            torch.trapezoid(square, square),
            torch.cumulative_trapezoid(square, square),
            # torch.triangular_solve(square, square),
            torch.vdot(vec, vec),
        )
452