1 //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Make changes to isl's schedule tree data structure.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef POLLY_SCHEDULETREETRANSFORM_H
14 #define POLLY_SCHEDULETREETRANSFORM_H
15 
16 #include "polly/Support/ISLTools.h"
17 #include "llvm/ADT/ArrayRef.h"
18 #include "llvm/Support/ErrorHandling.h"
19 #include "isl/isl-noexceptions.h"
20 #include <cassert>
21 
22 namespace polly {
23 struct BandAttr;
24 
25 /// This class defines a simple visitor class that may be used for
26 /// various schedule tree analysis purposes.
27 template <typename Derived, typename RetTy = void, typename... Args>
28 struct ScheduleTreeVisitor {
getDerivedScheduleTreeVisitor29   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerivedScheduleTreeVisitor30   const Derived &getDerived() const {
31     return *static_cast<const Derived *>(this);
32   }
33 
visitScheduleTreeVisitor34   RetTy visit(isl::schedule_node Node, Args... args) {
35     assert(!Node.is_null());
36     switch (isl_schedule_node_get_type(Node.get())) {
37     case isl_schedule_node_domain:
38       assert(isl_schedule_node_n_children(Node.get()) == 1);
39       return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(),
40                                       std::forward<Args>(args)...);
41     case isl_schedule_node_band:
42       assert(isl_schedule_node_n_children(Node.get()) == 1);
43       return getDerived().visitBand(Node.as<isl::schedule_node_band>(),
44                                     std::forward<Args>(args)...);
45     case isl_schedule_node_sequence:
46       assert(isl_schedule_node_n_children(Node.get()) >= 2);
47       return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(),
48                                         std::forward<Args>(args)...);
49     case isl_schedule_node_set:
50       return getDerived().visitSet(Node.as<isl::schedule_node_set>(),
51                                    std::forward<Args>(args)...);
52       assert(isl_schedule_node_n_children(Node.get()) >= 2);
53     case isl_schedule_node_leaf:
54       assert(isl_schedule_node_n_children(Node.get()) == 0);
55       return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(),
56                                     std::forward<Args>(args)...);
57     case isl_schedule_node_mark:
58       assert(isl_schedule_node_n_children(Node.get()) == 1);
59       return getDerived().visitMark(Node.as<isl::schedule_node_mark>(),
60                                     std::forward<Args>(args)...);
61     case isl_schedule_node_extension:
62       assert(isl_schedule_node_n_children(Node.get()) == 1);
63       return getDerived().visitExtension(
64           Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...);
65     case isl_schedule_node_filter:
66       assert(isl_schedule_node_n_children(Node.get()) == 1);
67       return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(),
68                                       std::forward<Args>(args)...);
69     default:
70       llvm_unreachable("unimplemented schedule node type");
71     }
72   }
73 
visitDomainScheduleTreeVisitor74   RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) {
75     return getDerived().visitSingleChild(std::move(Domain),
76                                          std::forward<Args>(args)...);
77   }
78 
visitBandScheduleTreeVisitor79   RetTy visitBand(isl::schedule_node_band Band, Args... args) {
80     return getDerived().visitSingleChild(std::move(Band),
81                                          std::forward<Args>(args)...);
82   }
83 
visitSequenceScheduleTreeVisitor84   RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) {
85     return getDerived().visitMultiChild(std::move(Sequence),
86                                         std::forward<Args>(args)...);
87   }
88 
visitSetScheduleTreeVisitor89   RetTy visitSet(isl::schedule_node_set Set, Args... args) {
90     return getDerived().visitMultiChild(std::move(Set),
91                                         std::forward<Args>(args)...);
92   }
93 
visitLeafScheduleTreeVisitor94   RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) {
95     return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...);
96   }
97 
visitMarkScheduleTreeVisitor98   RetTy visitMark(isl::schedule_node_mark Mark, Args... args) {
99     return getDerived().visitSingleChild(std::move(Mark),
100                                          std::forward<Args>(args)...);
101   }
102 
visitExtensionScheduleTreeVisitor103   RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) {
104     return getDerived().visitSingleChild(std::move(Extension),
105                                          std::forward<Args>(args)...);
106   }
107 
visitFilterScheduleTreeVisitor108   RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) {
109     return getDerived().visitSingleChild(std::move(Filter),
110                                          std::forward<Args>(args)...);
111   }
112 
visitSingleChildScheduleTreeVisitor113   RetTy visitSingleChild(isl::schedule_node Node, Args... args) {
114     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
115   }
116 
visitMultiChildScheduleTreeVisitor117   RetTy visitMultiChild(isl::schedule_node Node, Args... args) {
118     return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...);
119   }
120 
visitNodeScheduleTreeVisitor121   RetTy visitNode(isl::schedule_node Node, Args... args) {
122     llvm_unreachable("Unimplemented other");
123   }
124 };
125 
126 /// Recursively visit all nodes of a schedule tree.
127 template <typename Derived, typename RetTy = void, typename... Args>
128 struct RecursiveScheduleTreeVisitor
129     : ScheduleTreeVisitor<Derived, RetTy, Args...> {
130   using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>;
getBaseRecursiveScheduleTreeVisitor131   BaseTy &getBase() { return *this; }
getBaseRecursiveScheduleTreeVisitor132   const BaseTy &getBase() const { return *this; }
getDerivedRecursiveScheduleTreeVisitor133   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerivedRecursiveScheduleTreeVisitor134   const Derived &getDerived() const {
135     return *static_cast<const Derived *>(this);
136   }
137 
138   /// When visiting an entire schedule tree, start at its root node.
visitRecursiveScheduleTreeVisitor139   RetTy visit(isl::schedule Schedule, Args... args) {
140     return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...);
141   }
142 
143   // Necessary to allow overload resolution with the added visit(isl::schedule)
144   // overload.
visitRecursiveScheduleTreeVisitor145   RetTy visit(isl::schedule_node Node, Args... args) {
146     return getBase().visit(Node, std::forward<Args>(args)...);
147   }
148 
149   /// By default, recursively visit the child nodes.
visitNodeRecursiveScheduleTreeVisitor150   RetTy visitNode(isl::schedule_node Node, Args... args) {
151     for (unsigned i : rangeIslSize(0, Node.n_children()))
152       getDerived().visit(Node.child(i), std::forward<Args>(args)...);
153     return RetTy();
154   }
155 };
156 
157 /// Recursively visit all nodes of a schedule tree while allowing changes.
158 ///
159 /// The visit methods return an isl::schedule_node that is used to continue
160 /// visiting the tree. Structural changes such as returning a different node
161 /// will confuse the visitor.
162 template <typename Derived, typename... Args>
163 struct ScheduleNodeRewriter
164     : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node,
165                                           Args...> {
getDerivedScheduleNodeRewriter166   Derived &getDerived() { return *static_cast<Derived *>(this); }
getDerivedScheduleNodeRewriter167   const Derived &getDerived() const {
168     return *static_cast<const Derived *>(this);
169   }
170 
visitNodeScheduleNodeRewriter171   isl::schedule_node visitNode(isl::schedule_node Node, Args... args) {
172     return getDerived().visitChildren(Node);
173   }
174 
visitChildrenScheduleNodeRewriter175   isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) {
176     if (!Node.has_children())
177       return Node;
178 
179     isl::schedule_node It = Node.first_child();
180     while (true) {
181       It = getDerived().visit(It, std::forward<Args>(args)...);
182       if (!It.has_next_sibling())
183         break;
184       It = It.next_sibling();
185     }
186     return It.parent();
187   }
188 };
189 
190 /// Is this node the marker for its parent band?
191 bool isBandMark(const isl::schedule_node &Node);
192 
193 /// Extract the BandAttr from a band's wrapping marker. Can also pass the band
194 /// itself and this methods will try to find its wrapping mark. Returns nullptr
195 /// if the band has not BandAttr.
196 BandAttr *getBandAttr(isl::schedule_node MarkOrBand);
197 
198 /// Hoist all domains from extension into the root domain node, such that there
199 /// are no more extension nodes (which isl does not support for some
200 /// operations). This assumes that domains added by to extension nodes do not
201 /// overlap.
202 isl::schedule hoistExtensionNodes(isl::schedule Sched);
203 
204 /// Replace the AST band @p BandToUnroll by a sequence of all its iterations.
205 ///
206 /// The implementation enumerates all points in the partial schedule and creates
207 /// an ISL sequence node for each point. The number of iterations must be a
208 /// constant.
209 isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll);
210 
211 /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent.
212 isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor);
213 
214 /// Loop-distribute the band @p BandToFission as much as possible.
215 isl::schedule applyMaxFission(isl::schedule_node BandToFission);
216 
217 /// Build the desired set of partial tile prefixes.
218 ///
219 /// We build a set of partial tile prefixes, which are prefixes of the vector
220 /// loop that have exactly VectorWidth iterations.
221 ///
222 /// 1. Drop all constraints involving the dimension that represents the
223 ///    vector loop.
224 /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth
225 ///    iterations.
226 /// 3. Subtract loop domain from it, project out the vector loop dimension and
227 ///    get a set that contains prefixes, which do not have exactly VectorWidth
228 ///    iterations.
229 /// 4. Project out the vector loop dimension of the set that was build on the
230 ///    first step and subtract the set built on the previous step to get the
231 ///    desired set of prefixes.
232 ///
233 /// @param ScheduleRange A range of a map, which describes a prefix schedule
234 ///                      relation.
235 isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth);
236 
237 /// Create an isl::union_set, which describes the isolate option based on
238 /// IsolateDomain.
239 ///
240 /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should
241 ///                      belong to the current band node.
242 /// @param OutDimsNum    A number of dimensions that should belong to
243 ///                      the current band node.
244 isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum);
245 
246 /// Create an isl::union_set, which describes the specified option for the
247 /// dimension of the current node.
248 ///
249 /// @param Ctx    An isl::ctx, which is used to create the isl::union_set.
250 /// @param Option The name of the option.
251 isl::union_set getDimOptions(isl::ctx Ctx, const char *Option);
252 
253 /// Tile a schedule node.
254 ///
255 /// @param Node            The node to tile.
256 /// @param Identifier      An name that identifies this kind of tiling and
257 ///                        that is used to mark the tiled loops in the
258 ///                        generated AST.
259 /// @param TileSizes       A vector of tile sizes that should be used for
260 ///                        tiling.
261 /// @param DefaultTileSize A default tile size that is used for dimensions
262 ///                        that are not covered by the TileSizes vector.
263 isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier,
264                             llvm::ArrayRef<int> TileSizes, int DefaultTileSize);
265 
266 /// Tile a schedule node and unroll point loops.
267 ///
268 /// @param Node            The node to register tile.
269 /// @param TileSizes       A vector of tile sizes that should be used for
270 ///                        tiling.
271 /// @param DefaultTileSize A default tile size that is used for dimensions
272 isl::schedule_node applyRegisterTiling(isl::schedule_node Node,
273                                        llvm::ArrayRef<int> TileSizes,
274                                        int DefaultTileSize);
275 
276 /// Apply greedy fusion. That is, fuse any loop that is possible to be fused
277 /// top-down.
278 ///
279 /// @param Sched  Sched tree to fuse all the loops in.
280 /// @param Deps   Validity constraints that must be preserved.
281 isl::schedule applyGreedyFusion(isl::schedule Sched,
282                                 const isl::union_map &Deps);
283 
284 } // namespace polly
285 
286 #endif // POLLY_SCHEDULETREETRANSFORM_H
287