# xref: /aosp_15_r20/external/pytorch/benchmarks/dynamo/microbenchmarks/fx_microbenchmarks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import timeit
2
3import torch.fx
4
5
# Number of chained ``x.sin()`` calls traced into the benchmark graph.
N = 100_000
# Number of full passes over the graph's nodes per timed repeat.
K = 1_000
8
9
def huge_graph(num_ops=None):
    """Symbolically trace a long chain of ``sin`` calls into an FX GraphModule.

    Args:
        num_ops: number of chained ``x.sin()`` calls to trace. ``None`` (the
            default) uses the module-level ``N``, looked up at call time, so
            existing callers keep the original benchmark size. Pass a small
            value to build a tiny graph for quick checks.

    Returns:
        The ``torch.fx.GraphModule`` returned by ``torch.fx.symbolic_trace``.
    """
    if num_ops is None:
        # Resolve lazily so the default tracks the module constant.
        num_ops = N

    def fn(x):
        for _ in range(num_ops):
            x = x.sin()
        return x

    return torch.fx.symbolic_trace(fn)
17
18
def main():
    """Benchmark plain iteration over the nodes of a large FX graph.

    Builds the graph once with ``huge_graph()``. It then times ``K`` full
    passes over ``g.graph.nodes`` per repeat and keeps the best of 3 repeats
    from ``timeit.repeat``. Finally it prints the total number of nodes visited
    and the resulting throughput.
    """
    g = huge_graph()

    # The traced graph holds a placeholder and an output node besides the N
    # ``sin`` calls, so count what each pass really visits instead of
    # assuming exactly N.
    nodes_per_pass = sum(1 for _ in g.graph.nodes)

    def fn():
        # Keep ``g.graph.nodes`` inside the timed body, as in the original
        # benchmark.
        for _ in g.graph.nodes:
            pass

    t = min(timeit.repeat(fn, number=K, repeat=3))
    visited = nodes_per_pass * K
    print(f"iterating over {visited} FX nodes took {t:.1f}s ({visited/t:.0f} nodes/s)")


if __name__ == "__main__":
    main()
32