"""Micro-benchmark: how fast can we iterate over the nodes of a large FX graph?"""

import timeit

import torch.fx


N = 100000  # number of ``sin`` ops traced into the benchmark graph
K = 1000  # number of full passes over the graph per timing run


def huge_graph(n: int = N) -> torch.fx.GraphModule:
    """Symbolically trace a function that applies ``sin`` ``n`` times in a row.

    Args:
        n: number of chained ``x.sin()`` calls to record (defaults to ``N``).

    Returns:
        The traced ``GraphModule``. Besides the ``n`` ``sin`` nodes, its graph
        also holds a placeholder node for the input and an output node.
    """

    def fn(x):
        # The Python loop runs at trace time, unrolling into n graph nodes.
        for _ in range(n):
            x = x.sin()
        return x

    return torch.fx.symbolic_trace(fn)


def main(n: int = N, k: int = K, repeat: int = 3) -> None:
    """Time ``k`` passes over an ``n``-op FX graph and print the throughput.

    Args:
        n: number of ``sin`` ops in the benchmark graph (defaults to ``N``).
        k: passes over the graph per timing run (defaults to ``K``).
        repeat: number of timing runs; the fastest one is reported.
    """
    g = huge_graph(n)
    # Count the real nodes (this includes the placeholder and output nodes)
    # instead of assuming ``n``, so the reported totals are exact.
    total = len(g.graph.nodes) * k

    def fn():
        for _ in g.graph.nodes:
            pass

    # min() over repeats filters out scheduling noise, per timeit guidance.
    t = min(timeit.repeat(fn, number=k, repeat=repeat))
    rate = total / t if t > 0 else float("inf")
    print(f"iterating over {total} FX nodes took {t:.1f}s ({rate:.0f} nodes/s)")


if __name__ == "__main__":
    main()