xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/peephole_dict_idioms.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // Peephole-optimizes dict ops such as len() and __getitem__.
8 // 1. getitem optimizations
9 // Given a function like this:
10 //     def foo():
11 //         d = {0 : 1}
12 //         x = d[0]
13 //         return x
14 // This pass produces (after dead code elimination):
15 //     def foo():
16 //         return 1
17 //
18 // This optimization only applies if the dict is never modified
19 // and all of its keys are constant and non-overlapping.
20 //
21 // 2. len optimizations
22 // Given a function like this:
23 //     def foo():
24 //         d = {0 : 1}
25 //         return len(d)
26 // This pass produces (after dead code elimination):
27 //     def foo():
28 //         return 1
29 //
30 // This has the same requirements as the getitem optimizations.
31 //
32 // Currently this is invoked as part of PeepholeOptimize.
33 // Returns true if the graph was modified.
34 TORCH_API auto PeepholeOptimizeDictIdioms(const std::shared_ptr<Graph>& graph) -> bool;
35 
36 } // namespace torch::jit
37