xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import math
2
3import torch
4
5
def rounded_linspace(low, high, steps, div):
    """Return evenly spaced integers in ``[low, high]`` rounded up to ``div``.

    ``steps`` samples are taken with ``torch.linspace(low, high, steps)``.
    Each sample is truncated toward zero, then rounded up to the next
    multiple of ``div``. Samples that land on the same multiple are merged,
    so the result may contain fewer than ``steps`` entries. For example,
    ``rounded_linspace(1, 100, 5, 8) == [8, 32, 56, 80, 104]``.

    Args:
        low: First sample of the linear range.
        high: Last sample of the linear range.
        steps: Number of samples taken before rounding / de-duplication.
        div: Granularity every returned value is rounded up to.

    Returns:
        A sorted list of unique Python ints.
    """
    # Sample in float64 and convert to int64. float32 cannot represent every
    # integer above 2**24, and ``.int()`` (int32) overflows at 2**31, which
    # would silently corrupt large benchmark sizes.
    samples = torch.linspace(low, high, steps, dtype=torch.float64).long()
    # Integer ceil-to-multiple: (x + div - 1) // div * div.
    multiples = (samples + div - 1) // div * div
    return [int(value) for value in torch.unique(multiples).tolist()]
11
12
def powspace(start, stop, pow, step):
    """Return integers spaced geometrically (in base ``pow``) over ``[start, stop]``.

    The exponent range ``[log_pow(start), log_pow(stop)]`` is sampled at
    ``floor((log_pow(stop) - log_pow(start) + 1) / step)`` evenly spaced
    points, with both endpoints included. Each point ``e`` contributes
    ``int(pow ** e)``, truncated toward zero. Duplicates produced by the
    truncation are removed. For example,
    ``powspace(1, 1000, 10, 1) == [1, 10, 100, 1000]``.

    Args:
        start: Lower end of the value range (must be > 0).
        stop: Upper end of the value range (must be > 0).
        pow: Base of the progression. The parameter name is kept for
            callers; it shadows the builtin ``pow`` inside this function.
        step: Exponent spacing used to derive the number of points; smaller
            values give denser sampling.

    Returns:
        A sorted list of unique Python ints.
    """

    def _exponent(value):
        # math.log(x, base) is evaluated as log(x) / log(base) and is inexact
        # even for exact powers (math.log(1000, 10) == 2.9999999999999996).
        # Left alone, that drops a point from the step count and truncates
        # the endpoint (999 instead of 1000), so snap near-integral exponents.
        e = math.log(value, pow)
        nearest = round(e)
        if math.isclose(e, nearest, rel_tol=0.0, abs_tol=1e-9):
            return float(nearest)
        return e

    lo = _exponent(start)
    hi = _exponent(stop)
    # A tolerant floor instead of ``//``: with float steps, e.g.
    # 4.0 // 0.1 == 39.0, although the intended count is 40.
    steps = math.floor((hi - lo + 1) / step + 1e-9)
    # float64 keeps large powers (> 2**24) exact enough to truncate correctly.
    exponents = torch.linspace(lo, hi, steps, dtype=torch.float64)
    values = torch.pow(pow, exponents)
    # Snap results that are integral up to floating-point error (e.g. 10 ** 2
    # evaluated as 99.99999999999999) so truncation does not lose one.
    nearest = values.round()
    values = torch.where(
        torch.isclose(values, nearest, rtol=1e-9, atol=0.0), nearest, values
    )
    # De-duplicate *after* truncating, so distinct floats that truncate to
    # the same integer are reported only once.
    return [int(v) for v in torch.unique(values.long()).tolist()]
20