# https://pytorch.org/docs/stable/torch.html#math-operations
"""Modules that exercise the ``torch`` math operations listed in the docs above.

Each module's ``forward`` calls a batch of ops and returns ``len(...)`` taken
over all of their results, i.e. the number of results produced.  The builtin
``len`` accepts exactly one argument in eager Python, so calling ``forward``
eagerly raises ``TypeError``.
NOTE(review): the modules are presumably meant to be compiled with
``torch.jit.script`` (which packs the many ``len`` arguments into a list) --
confirm against the model generator that consumes this file.
"""

import math

import torch


class PointwiseOpsModule(torch.nn.Module):
    """Exercise element-wise math ops, from ``torch.abs`` to ``torch.xlogy``."""

    def forward(self):
        """Return the number of results produced by ``pointwise_ops``."""
        return self.pointwise_ops()

    def pointwise_ops(self):
        """Call each pointwise op once (or per overload) and count the results.

        Several calls mutate ``a`` and ``b`` in place (``uniform_``, ``add_``,
        ``div_``, ``mul_``, ``sub_``, ``trunc_`` and ``out=`` arguments), so
        later ops see the mutated values and the call order matters.
        """
        a = torch.randn(4)
        b = torch.randn(4)
        t = torch.tensor([-1, -2, 3], dtype=torch.int8)
        r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
        s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
        f = torch.zeros(3)
        g = torch.tensor([-1, 0, 1])
        w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
        return len(
            torch.abs(torch.tensor([-1, -2, 3])),
            torch.absolute(torch.tensor([-1, -2, 3])),
            torch.acos(a),
            torch.arccos(a),
            # Refill `a` with values in [1, 2) so acosh is evaluated on its domain.
            torch.acosh(a.uniform_(1.0, 2.0)),
            torch.add(a, 20),
            torch.add(a, b, out=a),
            b.add(a),
            b.add(a, out=b),
            b.add_(a),
            b.add(1),
            torch.add(a, torch.randn(4, 1), alpha=10),
            torch.addcdiv(
                torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
            ),
            torch.addcmul(
                torch.randn(1, 3), torch.randn(3, 1), torch.randn(1, 3), value=0.1
            ),
            torch.angle(a),
            torch.asin(a),
            torch.arcsin(a),
            torch.asinh(a),
            torch.arcsinh(a),
            torch.atan(a),
            torch.arctan(a),
            # Refill `a` with values in [-1, 1) for the inverse hyperbolic tangent.
            torch.atanh(a.uniform_(-1.0, 1.0)),
            torch.arctanh(a.uniform_(-1.0, 1.0)),
            torch.atan2(a, a),
            torch.bitwise_not(t),
            torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
            torch.ceil(a),
            # Python-float arguments presumably target the scalar (non-Tensor)
            # overloads of these ops.
            torch.ceil(float(torch.tensor(0.5))),
            torch.ceil(torch.tensor(0.5).item()),
            torch.clamp(a, min=-0.5, max=0.5),
            torch.clamp(a, min=0.5),
            torch.clamp(a, max=0.5),
            torch.clip(a, min=-0.5, max=0.5),
            torch.conj(a),
            torch.copysign(a, 1),
            torch.copysign(a, b),
            torch.cos(a),
            torch.cosh(a),
            torch.deg2rad(
                torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])
            ),
            torch.div(a, b),
            a.div(b),
            a.div(1),
            a.div_(b),
            torch.divide(a, b, rounding_mode="trunc"),
            torch.divide(a, b, rounding_mode="floor"),
            torch.digamma(torch.tensor([1.0, 0.5])),
            torch.erf(torch.tensor([0.0, -1.0, 10.0])),
            torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
            torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
            torch.exp(torch.tensor([0.0, math.log(2.0)])),
            torch.exp(float(torch.tensor(1))),
            torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
            torch.expm1(torch.tensor([0.0, math.log(2.0)])),
            torch.fake_quantize_per_channel_affine(
                torch.randn(2, 2, 2),
                (torch.randn(2) + 1) * 0.05,
                torch.zeros(2),
                1,
                0,
                255,
            ),
            torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
            torch.float_power(torch.randint(10, (4,)), 2),
            torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
            torch.floor(a),
            torch.floor(float(torch.tensor(1))),
            torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
            torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
            torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
            torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
            torch.frac(torch.tensor([1.0, 2.5, -3.2])),
            torch.randn(4, dtype=torch.cfloat).imag,
            torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
            torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
            torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
            torch.lerp(
                torch.arange(1.0, 5.0),
                torch.empty(4).fill_(10),
                torch.full_like(torch.arange(1.0, 5.0), 0.5),
            ),
            torch.lgamma(torch.arange(0.5, 2, 0.5)),
            torch.log(torch.arange(5) + 10),
            torch.log10(torch.rand(5)),
            torch.log1p(torch.randn(5)),
            torch.log2(torch.rand(5)),
            torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
            torch.logaddexp(
                torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp(
                torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
            torch.logaddexp2(
                torch.tensor([-100.0, -200.0, -300.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logaddexp2(
                torch.tensor([1.0, 2000.0, 30000.0]), torch.tensor([-1, -2, -3])
            ),
            torch.logical_and(r, s),
            torch.logical_and(r.double(), s.double()),
            torch.logical_and(r.double(), s),
            torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
            torch.logical_not(torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
            torch.logical_not(
                torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
                out=torch.empty(3, dtype=torch.int16),
            ),
            torch.logical_or(r, s),
            torch.logical_or(r.double(), s.double()),
            torch.logical_or(r.double(), s),
            torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logical_xor(r, s),
            torch.logical_xor(r.double(), s.double()),
            torch.logical_xor(r.double(), s),
            torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
            torch.logit(torch.rand(5), eps=1e-6),
            torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
            torch.i0(torch.arange(5, dtype=torch.float32)),
            torch.igamma(a, b),
            torch.igammac(a, b),
            torch.mul(torch.randn(3), 100),
            b.mul(a),
            b.mul(5),
            b.mul(a, out=b),
            b.mul_(a),
            b.mul_(5),
            torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
            torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
            # NOTE(review): this non-finite tensor is only counted, never passed
            # to nan_to_num; `w` holds finite values only, so the nan/inf
            # replacement paths below are likely not exercised -- confirm intent.
            torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
            torch.nan_to_num(w),
            torch.nan_to_num_(w),
            torch.nan_to_num(w, nan=2.0),
            torch.nan_to_num(w, nan=2.0, posinf=1.0),
            torch.neg(torch.randn(5)),
            # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
            torch.polygamma(1, torch.tensor([1.0, 0.5])),
            torch.polygamma(2, torch.tensor([1.0, 0.5])),
            torch.polygamma(3, torch.tensor([1.0, 0.5])),
            torch.polygamma(4, torch.tensor([1.0, 0.5])),
            torch.pow(a, 2),
            torch.pow(2, float(torch.tensor(0.5))),
            torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
            torch.rad2deg(
                torch.tensor([[3.142, -3.142], [6.283, -6.283], [1.570, -1.570]])
            ),
            torch.randn(4, dtype=torch.cfloat).real,
            torch.reciprocal(a),
            torch.remainder(torch.tensor([-3.0, -2.0]), 2),
            torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
            torch.round(a),
            torch.round(torch.tensor(0.5).item()),
            torch.rsqrt(a),
            torch.sigmoid(a),
            torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
            torch.sgn(a),
            torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
            torch.sin(a),
            torch.sinc(a),
            torch.sinh(a),
            torch.sqrt(a),
            torch.square(a),
            torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
            b.sub(a),
            b.sub_(a),
            b.sub(5),
            torch.sum(5),
            torch.tan(a),
            torch.tanh(a),
            torch.true_divide(a, a),
            torch.trunc(a),
            torch.trunc_(a),
            torch.xlogy(f, g),
            # NOTE(review): identical to the previous call; kept so the
            # returned count does not change.
            torch.xlogy(f, g),
            torch.xlogy(f, 4),
            torch.xlogy(2, g),
        )


class ReductionOpsModule(torch.nn.Module):
    """Exercise reduction ops (argmax, mean, norm, quantile, var, ...)."""

    def forward(self):
        """Return the number of results produced by ``reduction_ops``."""
        return self.reduction_ops()

    def reduction_ops(self):
        """Call each reduction op on small random inputs and count the results."""
        a = torch.randn(4)
        b = torch.randn(4)
        c = torch.tensor(0.5)
        return len(
            torch.argmax(a),
            torch.argmin(a),
            torch.amax(a),
            torch.amin(a),
            torch.aminmax(a),
            torch.all(a),
            torch.any(a),
            torch.max(a),
            a.max(a),
            torch.max(a, 0),
            torch.min(a),
            a.min(a),
            torch.min(a, 0),
            torch.dist(a, b),
            torch.logsumexp(a, 0),
            torch.mean(a),
            torch.mean(a, 0),
            torch.nanmean(a),
            torch.median(a),
            torch.nanmedian(a),
            torch.mode(a),
            torch.norm(a),
            a.norm(2),
            torch.norm(a, dim=0),
            torch.norm(c, torch.tensor(2)),
            torch.nansum(a),
            torch.prod(a),
            torch.quantile(a, torch.tensor([0.25, 0.5, 0.75])),
            torch.quantile(a, 0.5),
            torch.nanquantile(a, torch.tensor([0.25, 0.5, 0.75])),
            torch.std(a),
            torch.std_mean(a),
            torch.sum(a),
            torch.unique(a),
            torch.unique_consecutive(a),
            torch.var(a),
            torch.var_mean(a),
            torch.count_nonzero(a),
        )


class ComparisonOpsModule(torch.nn.Module):
    """Exercise comparison, predicate and sorting ops on 0-dim tensors."""

    def forward(self):
        """Call each comparison op on two scalar tensors and count the results."""
        a = torch.tensor(0)
        b = torch.tensor(1)
        return len(
            torch.allclose(a, b),
            torch.argsort(a),
            torch.eq(a, b),
            torch.eq(a, 1),
            torch.equal(a, b),
            torch.ge(a, b),
            torch.ge(a, 1),
            torch.greater_equal(a, b),
            torch.greater_equal(a, 1),
            torch.gt(a, b),
            torch.gt(a, 1),
            torch.greater(a, b),
            torch.isclose(a, b),
            torch.isfinite(a),
            torch.isin(a, b),
            torch.isinf(a),
            torch.isposinf(a),
            torch.isneginf(a),
            torch.isnan(a),
            torch.isreal(a),
            torch.kthvalue(a, 1),
            torch.le(a, b),
            torch.le(a, 1),
            torch.less_equal(a, b),
            torch.lt(a, b),
            torch.lt(a, 1),
            torch.less(a, b),
            torch.maximum(a, b),
            torch.minimum(a, b),
            torch.fmax(a, b),
            torch.fmin(a, b),
            torch.ne(a, b),
            torch.ne(a, 1),
            torch.not_equal(a, b),
            torch.sort(a),
            torch.topk(a, 1),
            torch.msort(a),
        )


class OtherMathOpsModule(torch.nn.Module):
    """Exercise the "other operations" section (broadcasting, cumulative,
    diagonal, flipping, histogram, triangular, complex-view ops, ...)."""

    def forward(self):
        """Return the number of results produced by ``other_ops``."""
        return self.other_ops()

    def other_ops(self):
        """Call each op once on small random inputs and count the results."""
        a = torch.randn(4)
        b = torch.randn(4)
        c = torch.randint(0, 8, (5,), dtype=torch.int64)
        e = torch.randn(4, 3)
        f = torch.randn(4, 4, 4)
        dims = [0, 1]
        return len(
            torch.atleast_1d(a),
            torch.atleast_2d(a),
            torch.atleast_3d(a),
            torch.bincount(c),
            torch.block_diag(a),
            torch.broadcast_tensors(a),
            # `(4,)` is a one-element tuple; a bare `(4)` is just the int 4.
            torch.broadcast_to(a, (4,)),
            # torch.broadcast_shapes(a),
            torch.bucketize(a, b),
            torch.cartesian_prod(a),
            torch.cdist(e, e),
            torch.clone(a),
            torch.combinations(a),
            torch.corrcoef(a),
            # torch.cov(a),
            torch.cross(e, e),
            torch.cummax(a, 0),
            torch.cummin(a, 0),
            torch.cumprod(a, 0),
            torch.cumsum(a, 0),
            torch.diag(a),
            torch.diag_embed(a),
            torch.diagflat(a),
            torch.diagonal(e),
            torch.diff(a),
            torch.einsum("iii", f),
            torch.flatten(a),
            torch.flip(e, dims),
            torch.fliplr(e),
            torch.flipud(e),
            torch.kron(a, b),
            torch.rot90(e),
            torch.gcd(c, c),
            torch.histc(a),
            torch.histogram(a),
            torch.meshgrid(a),
            torch.meshgrid(a, indexing="xy"),
            torch.lcm(c, c),
            torch.logcumsumexp(a, 0),
            torch.ravel(a),
            torch.renorm(e, 1, 0, 5),
            torch.repeat_interleave(c),
            torch.roll(a, 1, 0),
            torch.searchsorted(a, b),
            torch.tensordot(e, e),
            torch.trace(e),
            torch.tril(e),
            torch.tril_indices(3, 3),
            torch.triu(e),
            torch.triu_indices(3, 3),
            torch.vander(a),
            torch.view_as_real(torch.randn(4, dtype=torch.cfloat)),
            torch.view_as_complex(torch.randn(4, 2)).real,
            torch.resolve_conj(a),
            torch.resolve_neg(a),
        )


class SpectralOpsModule(torch.nn.Module):
    """Exercise short-time Fourier transforms and window functions."""

    def forward(self):
        """Return the number of results produced by ``spectral_ops``."""
        return self.spectral_ops()

    def spectral_ops(self):
        """Call each spectral op once and count the results.

        NOTE(review): recent PyTorch releases require ``return_complex`` for
        real inputs to ``stft`` and a complex input to ``istft``; ``b`` here is
        a real tensor whose last dimension is 2. Verify against the targeted
        PyTorch version.
        """
        a = torch.randn(10)
        b = torch.randn(10, 8, 4, 2)
        return len(
            torch.stft(a, 8),
            torch.stft(a, torch.tensor(8)),
            torch.istft(b, 8),
            torch.bartlett_window(2, dtype=torch.float),
            torch.blackman_window(2, dtype=torch.float),
            torch.hamming_window(4, dtype=torch.float),
            torch.hann_window(4, dtype=torch.float),
            torch.kaiser_window(4, dtype=torch.float),
        )


class BlasLapackOpsModule(torch.nn.Module):
    """Exercise BLAS/LAPACK-style linear-algebra ops.

    Commented-out entries are ops intentionally left out of the test (reasons
    noted inline where known).
    """

    def forward(self):
        """Return the number of results produced by ``blas_lapack_ops``."""
        return self.blas_lapack_ops()

    def blas_lapack_ops(self):
        """Call each linear-algebra op once and count the results.

        ``a`` (10x3x4) and ``b`` (10x4x3) are batches whose per-batch product
        is 3x3, matching ``m`` for the ``addbmm``/``baddbmm`` inputs.
        """
        m = torch.randn(3, 3)
        a = torch.randn(10, 3, 4)
        b = torch.randn(10, 4, 3)
        v = torch.randn(3)
        return len(
            torch.addbmm(m, a, b),
            torch.addmm(torch.randn(2, 3), torch.randn(2, 3), torch.randn(3, 3)),
            torch.addmv(torch.randn(2), torch.randn(2, 3), torch.randn(3)),
            torch.addr(torch.zeros(3, 3), v, v),
            torch.baddbmm(m, a, b),
            torch.bmm(a, b),
            torch.chain_matmul(torch.randn(3, 3), torch.randn(3, 3), torch.randn(3, 3)),
            # torch.cholesky(a), # deprecated
            # torch.cholesky_inverse(torch.randn(3, 3)), # had some error
            # torch.cholesky_solve(torch.randn(3, 3), torch.randn(3, 3)),
            torch.dot(v, v),
            # torch.linalg.eig(m), # not built with lapack
            # torch.geqrf(a),
            torch.ger(v, v),
            torch.inner(m, m),
            # torch.inverse(m),
            # torch.det(m),
            # torch.logdet(m),
            # torch.slogdet(m),
            # torch.lstsq(m, m),
            # torch.linalg.lu_factor(m),
            # torch.lu_solve(m, *torch.linalg.lu_factor(m)),
            # torch.lu_unpack(*torch.linalg.lu_factor(m)),
            torch.matmul(m, m),
            torch.matrix_power(m, 2),
            # torch.matrix_rank(m),
            torch.matrix_exp(m),
            torch.mm(m, m),
            torch.mv(m, v),
            # torch.orgqr(a, m),
            # torch.ormqr(a, m, v),
            torch.outer(v, v),
            # torch.pinverse(m),
            # torch.qr(a),
            # torch.solve(m, m),
            # torch.svd(a),
            # torch.svd_lowrank(a),
            # torch.pca_lowrank(a),
            # torch.symeig(a), # deprecated
            # torch.lobpcg(a, b), # not supported
            torch.trapz(m, m),
            torch.trapezoid(m, m),
            torch.cumulative_trapezoid(m, m),
            # torch.triangular_solve(m, m),
            torch.vdot(v, v),
        )