1 //===- polly/ScheduleTreeTransform.h ----------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Make changes to isl's schedule tree data structure. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef POLLY_SCHEDULETREETRANSFORM_H 14 #define POLLY_SCHEDULETREETRANSFORM_H 15 16 #include "polly/Support/ISLTools.h" 17 #include "llvm/ADT/ArrayRef.h" 18 #include "llvm/Support/ErrorHandling.h" 19 #include "isl/isl-noexceptions.h" 20 #include <cassert> 21 22 namespace polly { 23 struct BandAttr; 24 25 /// This class defines a simple visitor class that may be used for 26 /// various schedule tree analysis purposes. 27 template <typename Derived, typename RetTy = void, typename... Args> 28 struct ScheduleTreeVisitor { getDerivedScheduleTreeVisitor29 Derived &getDerived() { return *static_cast<Derived *>(this); } getDerivedScheduleTreeVisitor30 const Derived &getDerived() const { 31 return *static_cast<const Derived *>(this); 32 } 33 visitScheduleTreeVisitor34 RetTy visit(isl::schedule_node Node, Args... args) { 35 assert(!Node.is_null()); 36 switch (isl_schedule_node_get_type(Node.get())) { 37 case isl_schedule_node_domain: 38 assert(isl_schedule_node_n_children(Node.get()) == 1); 39 return getDerived().visitDomain(Node.as<isl::schedule_node_domain>(), 40 std::forward<Args>(args)...); 41 case isl_schedule_node_band: 42 assert(isl_schedule_node_n_children(Node.get()) == 1); 43 return getDerived().visitBand(Node.as<isl::schedule_node_band>(), 44 std::forward<Args>(args)...); 45 case isl_schedule_node_sequence: 46 assert(isl_schedule_node_n_children(Node.get()) >= 2); 47 return getDerived().visitSequence(Node.as<isl::schedule_node_sequence>(), 48 std::forward<Args>(args)...); 49 case isl_schedule_node_set: 50 return getDerived().visitSet(Node.as<isl::schedule_node_set>(), 51 std::forward<Args>(args)...); 52 assert(isl_schedule_node_n_children(Node.get()) >= 2); 53 case isl_schedule_node_leaf: 54 assert(isl_schedule_node_n_children(Node.get()) == 0); 55 return getDerived().visitLeaf(Node.as<isl::schedule_node_leaf>(), 56 std::forward<Args>(args)...); 57 case isl_schedule_node_mark: 58 assert(isl_schedule_node_n_children(Node.get()) == 1); 59 return getDerived().visitMark(Node.as<isl::schedule_node_mark>(), 60 std::forward<Args>(args)...); 61 case isl_schedule_node_extension: 62 assert(isl_schedule_node_n_children(Node.get()) == 1); 63 return getDerived().visitExtension( 64 Node.as<isl::schedule_node_extension>(), std::forward<Args>(args)...); 65 case isl_schedule_node_filter: 66 assert(isl_schedule_node_n_children(Node.get()) == 1); 67 return getDerived().visitFilter(Node.as<isl::schedule_node_filter>(), 68 std::forward<Args>(args)...); 69 default: 70 llvm_unreachable("unimplemented schedule node type"); 71 } 72 } 73 visitDomainScheduleTreeVisitor74 RetTy visitDomain(isl::schedule_node_domain Domain, Args... args) { 75 return getDerived().visitSingleChild(std::move(Domain), 76 std::forward<Args>(args)...); 77 } 78 visitBandScheduleTreeVisitor79 RetTy visitBand(isl::schedule_node_band Band, Args... args) { 80 return getDerived().visitSingleChild(std::move(Band), 81 std::forward<Args>(args)...); 82 } 83 visitSequenceScheduleTreeVisitor84 RetTy visitSequence(isl::schedule_node_sequence Sequence, Args... args) { 85 return getDerived().visitMultiChild(std::move(Sequence), 86 std::forward<Args>(args)...); 87 } 88 visitSetScheduleTreeVisitor89 RetTy visitSet(isl::schedule_node_set Set, Args... args) { 90 return getDerived().visitMultiChild(std::move(Set), 91 std::forward<Args>(args)...); 92 } 93 visitLeafScheduleTreeVisitor94 RetTy visitLeaf(isl::schedule_node_leaf Leaf, Args... args) { 95 return getDerived().visitNode(std::move(Leaf), std::forward<Args>(args)...); 96 } 97 visitMarkScheduleTreeVisitor98 RetTy visitMark(isl::schedule_node_mark Mark, Args... args) { 99 return getDerived().visitSingleChild(std::move(Mark), 100 std::forward<Args>(args)...); 101 } 102 visitExtensionScheduleTreeVisitor103 RetTy visitExtension(isl::schedule_node_extension Extension, Args... args) { 104 return getDerived().visitSingleChild(std::move(Extension), 105 std::forward<Args>(args)...); 106 } 107 visitFilterScheduleTreeVisitor108 RetTy visitFilter(isl::schedule_node_filter Filter, Args... args) { 109 return getDerived().visitSingleChild(std::move(Filter), 110 std::forward<Args>(args)...); 111 } 112 visitSingleChildScheduleTreeVisitor113 RetTy visitSingleChild(isl::schedule_node Node, Args... args) { 114 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); 115 } 116 visitMultiChildScheduleTreeVisitor117 RetTy visitMultiChild(isl::schedule_node Node, Args... args) { 118 return getDerived().visitNode(std::move(Node), std::forward<Args>(args)...); 119 } 120 visitNodeScheduleTreeVisitor121 RetTy visitNode(isl::schedule_node Node, Args... args) { 122 llvm_unreachable("Unimplemented other"); 123 } 124 }; 125 126 /// Recursively visit all nodes of a schedule tree. 127 template <typename Derived, typename RetTy = void, typename... Args> 128 struct RecursiveScheduleTreeVisitor 129 : ScheduleTreeVisitor<Derived, RetTy, Args...> { 130 using BaseTy = ScheduleTreeVisitor<Derived, RetTy, Args...>; getBaseRecursiveScheduleTreeVisitor131 BaseTy &getBase() { return *this; } getBaseRecursiveScheduleTreeVisitor132 const BaseTy &getBase() const { return *this; } getDerivedRecursiveScheduleTreeVisitor133 Derived &getDerived() { return *static_cast<Derived *>(this); } getDerivedRecursiveScheduleTreeVisitor134 const Derived &getDerived() const { 135 return *static_cast<const Derived *>(this); 136 } 137 138 /// When visiting an entire schedule tree, start at its root node. visitRecursiveScheduleTreeVisitor139 RetTy visit(isl::schedule Schedule, Args... args) { 140 return getDerived().visit(Schedule.get_root(), std::forward<Args>(args)...); 141 } 142 143 // Necessary to allow overload resolution with the added visit(isl::schedule) 144 // overload. visitRecursiveScheduleTreeVisitor145 RetTy visit(isl::schedule_node Node, Args... args) { 146 return getBase().visit(Node, std::forward<Args>(args)...); 147 } 148 149 /// By default, recursively visit the child nodes. visitNodeRecursiveScheduleTreeVisitor150 RetTy visitNode(isl::schedule_node Node, Args... args) { 151 for (unsigned i : rangeIslSize(0, Node.n_children())) 152 getDerived().visit(Node.child(i), std::forward<Args>(args)...); 153 return RetTy(); 154 } 155 }; 156 157 /// Recursively visit all nodes of a schedule tree while allowing changes. 158 /// 159 /// The visit methods return an isl::schedule_node that is used to continue 160 /// visiting the tree. Structural changes such as returning a different node 161 /// will confuse the visitor. 162 template <typename Derived, typename... Args> 163 struct ScheduleNodeRewriter 164 : public RecursiveScheduleTreeVisitor<Derived, isl::schedule_node, 165 Args...> { getDerivedScheduleNodeRewriter166 Derived &getDerived() { return *static_cast<Derived *>(this); } getDerivedScheduleNodeRewriter167 const Derived &getDerived() const { 168 return *static_cast<const Derived *>(this); 169 } 170 visitNodeScheduleNodeRewriter171 isl::schedule_node visitNode(isl::schedule_node Node, Args... args) { 172 return getDerived().visitChildren(Node); 173 } 174 visitChildrenScheduleNodeRewriter175 isl::schedule_node visitChildren(isl::schedule_node Node, Args... args) { 176 if (!Node.has_children()) 177 return Node; 178 179 isl::schedule_node It = Node.first_child(); 180 while (true) { 181 It = getDerived().visit(It, std::forward<Args>(args)...); 182 if (!It.has_next_sibling()) 183 break; 184 It = It.next_sibling(); 185 } 186 return It.parent(); 187 } 188 }; 189 190 /// Is this node the marker for its parent band? 191 bool isBandMark(const isl::schedule_node &Node); 192 193 /// Extract the BandAttr from a band's wrapping marker. Can also pass the band 194 /// itself and this methods will try to find its wrapping mark. Returns nullptr 195 /// if the band has not BandAttr. 196 BandAttr *getBandAttr(isl::schedule_node MarkOrBand); 197 198 /// Hoist all domains from extension into the root domain node, such that there 199 /// are no more extension nodes (which isl does not support for some 200 /// operations). This assumes that domains added by to extension nodes do not 201 /// overlap. 202 isl::schedule hoistExtensionNodes(isl::schedule Sched); 203 204 /// Replace the AST band @p BandToUnroll by a sequence of all its iterations. 205 /// 206 /// The implementation enumerates all points in the partial schedule and creates 207 /// an ISL sequence node for each point. The number of iterations must be a 208 /// constant. 209 isl::schedule applyFullUnroll(isl::schedule_node BandToUnroll); 210 211 /// Replace the AST band @p BandToUnroll by a partially unrolled equivalent. 212 isl::schedule applyPartialUnroll(isl::schedule_node BandToUnroll, int Factor); 213 214 /// Loop-distribute the band @p BandToFission as much as possible. 215 isl::schedule applyMaxFission(isl::schedule_node BandToFission); 216 217 /// Build the desired set of partial tile prefixes. 218 /// 219 /// We build a set of partial tile prefixes, which are prefixes of the vector 220 /// loop that have exactly VectorWidth iterations. 221 /// 222 /// 1. Drop all constraints involving the dimension that represents the 223 /// vector loop. 224 /// 2. Constrain the last dimension to get a set, which has exactly VectorWidth 225 /// iterations. 226 /// 3. Subtract loop domain from it, project out the vector loop dimension and 227 /// get a set that contains prefixes, which do not have exactly VectorWidth 228 /// iterations. 229 /// 4. Project out the vector loop dimension of the set that was build on the 230 /// first step and subtract the set built on the previous step to get the 231 /// desired set of prefixes. 232 /// 233 /// @param ScheduleRange A range of a map, which describes a prefix schedule 234 /// relation. 235 isl::set getPartialTilePrefixes(isl::set ScheduleRange, int VectorWidth); 236 237 /// Create an isl::union_set, which describes the isolate option based on 238 /// IsolateDomain. 239 /// 240 /// @param IsolateDomain An isl::set whose @p OutDimsNum last dimensions should 241 /// belong to the current band node. 242 /// @param OutDimsNum A number of dimensions that should belong to 243 /// the current band node. 244 isl::union_set getIsolateOptions(isl::set IsolateDomain, unsigned OutDimsNum); 245 246 /// Create an isl::union_set, which describes the specified option for the 247 /// dimension of the current node. 248 /// 249 /// @param Ctx An isl::ctx, which is used to create the isl::union_set. 250 /// @param Option The name of the option. 251 isl::union_set getDimOptions(isl::ctx Ctx, const char *Option); 252 253 /// Tile a schedule node. 254 /// 255 /// @param Node The node to tile. 256 /// @param Identifier An name that identifies this kind of tiling and 257 /// that is used to mark the tiled loops in the 258 /// generated AST. 259 /// @param TileSizes A vector of tile sizes that should be used for 260 /// tiling. 261 /// @param DefaultTileSize A default tile size that is used for dimensions 262 /// that are not covered by the TileSizes vector. 263 isl::schedule_node tileNode(isl::schedule_node Node, const char *Identifier, 264 llvm::ArrayRef<int> TileSizes, int DefaultTileSize); 265 266 /// Tile a schedule node and unroll point loops. 267 /// 268 /// @param Node The node to register tile. 269 /// @param TileSizes A vector of tile sizes that should be used for 270 /// tiling. 271 /// @param DefaultTileSize A default tile size that is used for dimensions 272 isl::schedule_node applyRegisterTiling(isl::schedule_node Node, 273 llvm::ArrayRef<int> TileSizes, 274 int DefaultTileSize); 275 276 /// Apply greedy fusion. That is, fuse any loop that is possible to be fused 277 /// top-down. 278 /// 279 /// @param Sched Sched tree to fuse all the loops in. 280 /// @param Deps Validity constraints that must be preserved. 281 isl::schedule applyGreedyFusion(isl::schedule Sched, 282 const isl::union_map &Deps); 283 284 } // namespace polly 285 286 #endif // POLLY_SCHEDULETREETRANSFORM_H 287