xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/peephole_list_idioms.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // Peephole Optimizes List ops such as len(li) and li[1].
8 // 1. Construct/Unpack optimizations
9 // Given a function like this:
10 //    def foo(a, b):
11 //        li = [a, b]
12 //        x, y = li
13 //        return x, y
14 // This pass produces (after dead code elimination):
15 //    def foo(a, b):
16 //        return a, b
17 //
18 // This is only applied to lists that are not modified.
19 //
20 // 2. getitem optimizations
21 // Given a function like this:
22 //     def foo(a, b):
23 //         li = [a, b]
24 //         x = li[0]
25 //         return x
26 // This pass produces (after dead code elimination):
27 //     def foo(a, b):
28 //         return a
29 //
30 // This optimization can only happen if the list is not modified.
31 //
32 // 3. len optimizations
33 // Given a function like this:
34 //     def foo():
35 //         li = [1, 2]
36 //         return len(li)
37 // This pass produces (after dead code elimination):
38 //     def foo():
39 //         return 2
40 //
41 // This has the same requirements as the getitem optimizations.
42 //
43 // 4. ListConstruct + ListConstruct
44 // Given a function like this:
45 //     def foo():
46 //         return [1, 2] + [3, 4]
47 // This pass produces (after dead code elimination):
48 //     def foo():
49 //         return [1, 2, 3, 4]
50 //
51 // This is only applied to lists that are not modified.
52 //
53 // 5. Slice
54 // Given a function like this:
55 //     def foo():
56 //         return [1, 2, 3, 4, 5][0:2]
57 // This pass produces (after deadcode elimination):
58 //     def foo():
59 //         return [1, 2]
60 //
61 // Currently this is invoked as part of PeepholeOptimize
62 // return true if graph is modified.
63 // If `refine_list_len` is true will attempt to refine the len of lists through
64 // len comparisons and assertions. This does not generally optimize pytorch
65 // programs so it is not called by default in PeepholeOptimize.
66 TORCH_API bool PeepholeOptimizeListIdioms(
67     const std::shared_ptr<Graph>& graph,
68     bool refine_list_len = false);
69 
70 } // namespace torch::jit
71