xref: /aosp_15_r20/external/skia/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "include/core/SkTypes.h"
9 #include "include/private/base/SkFloatingPoint.h"
10 #include "src/sksl/SkSLAnalysis.h"
11 #include "src/sksl/SkSLConstantFolder.h"
12 #include "src/sksl/SkSLErrorReporter.h"
13 #include "src/sksl/SkSLOperator.h"
14 #include "src/sksl/SkSLPosition.h"
15 #include "src/sksl/analysis/SkSLNoOpErrorReporter.h"
16 #include "src/sksl/ir/SkSLBinaryExpression.h"
17 #include "src/sksl/ir/SkSLExpression.h"
18 #include "src/sksl/ir/SkSLForStatement.h"
19 #include "src/sksl/ir/SkSLIRNode.h"
20 #include "src/sksl/ir/SkSLPostfixExpression.h"
21 #include "src/sksl/ir/SkSLPrefixExpression.h"
22 #include "src/sksl/ir/SkSLStatement.h"
23 #include "src/sksl/ir/SkSLType.h"
24 #include "src/sksl/ir/SkSLVarDeclarations.h"
25 #include "src/sksl/ir/SkSLVariable.h"
26 #include "src/sksl/ir/SkSLVariableReference.h"
27 
28 #include <cmath>
29 #include <memory>
30 
31 namespace SkSL {
32 
33 class Context;
34 
// Any loop needing this many iterations (or more) is rejected: unrolling it would exceed our
// program size limit.
static constexpr int kLoopTerminationLimit = 100000;

// Counts how many times a loop body executes when the index begins at `start`, advances by
// `delta` per iteration, and continues while it is below `end` (`forwards`) or above `end`
// (backwards). With `inclusive` set (i.e. `<=` / `>=` conditions), landing exactly on `end`
// still executes one more iteration. Loops that cannot terminate, or that would run more than
// kLoopTerminationLimit times, report kLoopTerminationLimit.
static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
    // The index may already be past the end, in which case the body never runs.
    const bool alreadyDone = forwards ? (start > end) : (start < end);
    if (alreadyDone) {
        return 0;
    }

    // A zero step, or a step pointing away from `end`, never reaches a completed state.
    const bool movesTowardEnd = (delta != 0.0) && ((delta > 0.0) == forwards);
    if (!movesTowardEnd) {
        return kLoopTerminationLimit;
    }

    // `delta` is known to be nonzero here, so this quotient is well defined.
    const double exactSteps = (end - start) / delta;
    double steps = std::ceil(exactSteps);
    if (inclusive && steps == exactSteps) {
        // Hitting the endpoint exactly grants one extra iteration for `<=` / `>=`.
        steps += 1.0;
    }

    // Too many iterations (or a non-finite count) cannot be unrolled safely.
    if (!std::isfinite(steps) || steps > kLoopTerminationLimit) {
        return kLoopTerminationLimit;
    }
    return static_cast<int>(steps);
}
58 
GetLoopUnrollInfo(const Context & context,Position loopPos,const ForLoopPositions & positions,const Statement * loopInitializer,std::unique_ptr<Expression> * loopTest,const Expression * loopNext,const Statement * loopStatement,ErrorReporter * errorPtr)59 std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& context,
60                                                             Position loopPos,
61                                                             const ForLoopPositions& positions,
62                                                             const Statement* loopInitializer,
63                                                             std::unique_ptr<Expression>* loopTest,
64                                                             const Expression* loopNext,
65                                                             const Statement* loopStatement,
66                                                             ErrorReporter* errorPtr) {
67     NoOpErrorReporter unused;
68     ErrorReporter& errors = errorPtr ? *errorPtr : unused;
69 
70     auto loopInfo = std::make_unique<LoopUnrollInfo>();
71 
72     //
73     // init_declaration has the form: type_specifier identifier = constant_expression
74     //
75     if (!loopInitializer) {
76         Position pos = positions.initPosition.valid() ? positions.initPosition : loopPos;
77         errors.error(pos, "missing init declaration");
78         return nullptr;
79     }
80     if (!loopInitializer->is<VarDeclaration>()) {
81         errors.error(loopInitializer->fPosition, "invalid init declaration");
82         return nullptr;
83     }
84     const VarDeclaration& initDecl = loopInitializer->as<VarDeclaration>();
85     if (!initDecl.baseType().isNumber()) {
86         errors.error(loopInitializer->fPosition, "invalid type for loop index");
87         return nullptr;
88     }
89     if (initDecl.arraySize() != 0) {
90         errors.error(loopInitializer->fPosition, "invalid type for loop index");
91         return nullptr;
92     }
93     if (!initDecl.value()) {
94         errors.error(loopInitializer->fPosition, "missing loop index initializer");
95         return nullptr;
96     }
97     if (!ConstantFolder::GetConstantValue(*initDecl.value(), &loopInfo->fStart)) {
98         errors.error(loopInitializer->fPosition,
99                      "loop index initializer must be a constant expression");
100         return nullptr;
101     }
102 
103     loopInfo->fIndex = initDecl.var();
104 
105     auto is_loop_index = [&](const std::unique_ptr<Expression>& expr) {
106         return expr->is<VariableReference>() &&
107                expr->as<VariableReference>().variable() == loopInfo->fIndex;
108     };
109 
110     //
111     // condition has the form: loop_index relational_operator constant_expression
112     //
113     if (!loopTest || !*loopTest) {
114         Position pos = positions.conditionPosition.valid() ? positions.conditionPosition : loopPos;
115         errors.error(pos, "missing condition");
116         return nullptr;
117     }
118     if (!loopTest->get()->is<BinaryExpression>()) {
119         errors.error(loopTest->get()->fPosition, "invalid condition");
120         return nullptr;
121     }
122     const BinaryExpression* cond = &loopTest->get()->as<BinaryExpression>();
123     if (!is_loop_index(cond->left())) {
124         errors.error(cond->fPosition, "expected loop index on left hand side of condition");
125         return nullptr;
126     }
127     // relational_operator is one of: > >= < <= == or !=
128     switch (cond->getOperator().kind()) {
129         case Operator::Kind::GT:
130         case Operator::Kind::GTEQ:
131         case Operator::Kind::LT:
132         case Operator::Kind::LTEQ:
133         case Operator::Kind::EQEQ:
134         case Operator::Kind::NEQ:
135             break;
136         default:
137             errors.error(cond->fPosition, "invalid relational operator");
138             return nullptr;
139     }
140     double loopEnd = 0;
141     if (!ConstantFolder::GetConstantValue(*cond->right(), &loopEnd)) {
142         errors.error(cond->fPosition, "loop index must be compared with a constant expression");
143         return nullptr;
144     }
145 
146     //
147     // expression has one of the following forms:
148     //   loop_index++
149     //   loop_index--
150     //   loop_index += constant_expression
151     //   loop_index -= constant_expression
152     // The spec doesn't mention prefix increment and decrement, but there is some consensus that
153     // it's an oversight, so we allow those as well.
154     //
155     if (!loopNext) {
156         Position pos = positions.nextPosition.valid() ? positions.nextPosition : loopPos;
157         errors.error(pos, "missing loop expression");
158         return nullptr;
159     }
160     switch (loopNext->kind()) {
161         case Expression::Kind::kBinary: {
162             const BinaryExpression& next = loopNext->as<BinaryExpression>();
163             if (!is_loop_index(next.left())) {
164                 errors.error(loopNext->fPosition, "expected loop index in loop expression");
165                 return nullptr;
166             }
167             if (!ConstantFolder::GetConstantValue(*next.right(), &loopInfo->fDelta)) {
168                 errors.error(loopNext->fPosition,
169                              "loop index must be modified by a constant expression");
170                 return nullptr;
171             }
172             switch (next.getOperator().kind()) {
173                 case Operator::Kind::PLUSEQ:                                        break;
174                 case Operator::Kind::MINUSEQ: loopInfo->fDelta = -loopInfo->fDelta; break;
175                 default:
176                     errors.error(loopNext->fPosition, "invalid operator in loop expression");
177                     return nullptr;
178             }
179             break;
180         }
181         case Expression::Kind::kPrefix: {
182             const PrefixExpression& next = loopNext->as<PrefixExpression>();
183             if (!is_loop_index(next.operand())) {
184                 errors.error(loopNext->fPosition, "expected loop index in loop expression");
185                 return nullptr;
186             }
187             switch (next.getOperator().kind()) {
188                 case Operator::Kind::PLUSPLUS:   loopInfo->fDelta =  1; break;
189                 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
190                 default:
191                     errors.error(loopNext->fPosition, "invalid operator in loop expression");
192                     return nullptr;
193             }
194             break;
195         }
196         case Expression::Kind::kPostfix: {
197             const PostfixExpression& next = loopNext->as<PostfixExpression>();
198             if (!is_loop_index(next.operand())) {
199                 errors.error(loopNext->fPosition, "expected loop index in loop expression");
200                 return nullptr;
201             }
202             switch (next.getOperator().kind()) {
203                 case Operator::Kind::PLUSPLUS:   loopInfo->fDelta =  1; break;
204                 case Operator::Kind::MINUSMINUS: loopInfo->fDelta = -1; break;
205                 default:
206                     errors.error(loopNext->fPosition, "invalid operator in loop expression");
207                     return nullptr;
208             }
209             break;
210         }
211         default:
212             errors.error(loopNext->fPosition, "invalid loop expression");
213             return nullptr;
214     }
215 
216     //
217     // Within the body of the loop, the loop index is not statically assigned to, nor is it used as
218     // argument to a function 'out' or 'inout' parameter.
219     //
220     if (Analysis::StatementWritesToVariable(*loopStatement, *initDecl.var())) {
221         errors.error(loopStatement->fPosition,
222                      "loop index must not be modified within body of the loop");
223         return nullptr;
224     }
225 
226     // Finally, compute the iteration count, based on the bounds, and the termination operator.
227     loopInfo->fCount = 0;
228 
229     switch (cond->getOperator().kind()) {
230         case Operator::Kind::LT:
231             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
232                                               /*forwards=*/true, /*inclusive=*/false);
233             break;
234 
235         case Operator::Kind::GT:
236             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
237                                               /*forwards=*/false, /*inclusive=*/false);
238             break;
239 
240         case Operator::Kind::LTEQ:
241             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
242                                               /*forwards=*/true, /*inclusive=*/true);
243             break;
244 
245         case Operator::Kind::GTEQ:
246             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
247                                               /*forwards=*/false, /*inclusive=*/true);
248             break;
249 
250         case Operator::Kind::NEQ: {
251             float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta);
252             loopInfo->fCount = std::ceil(iterations);
253             if (loopInfo->fCount < 0 || loopInfo->fCount != iterations ||
254                 !std::isfinite(iterations)) {
255                 // The loop doesn't reach the exact endpoint and so will never terminate.
256                 loopInfo->fCount = kLoopTerminationLimit;
257             }
258             if (loopInfo->fIndex->type().componentType().isFloat()) {
259                 // Rewrite `x != n` tests as `x < n` or `x > n` depending on the loop direction.
260                 // Less-than and greater-than tests avoid infinite loops caused by rounding error.
261                 Operator::Kind op = (loopInfo->fDelta > 0) ? Operator::Kind::LT
262                                                            : Operator::Kind::GT;
263                 *loopTest = BinaryExpression::Make(context,
264                                                    cond->fPosition,
265                                                    cond->left()->clone(),
266                                                    op,
267                                                    cond->right()->clone());
268                 cond = &loopTest->get()->as<BinaryExpression>();
269             }
270             break;
271         }
272         case Operator::Kind::EQEQ: {
273             if (loopInfo->fStart == loopEnd) {
274                 // Start and end begin in the same place, so we can run one iteration...
275                 if (loopInfo->fDelta) {
276                     // ... and then they diverge, so the loop terminates.
277                     loopInfo->fCount = 1;
278                 } else {
279                     // ... but they never diverge, so the loop runs forever.
280                     loopInfo->fCount = kLoopTerminationLimit;
281                 }
282             } else {
283                 // Start never equals end, so the loop will not run a single iteration.
284                 loopInfo->fCount = 0;
285             }
286             break;
287         }
288         default: SkUNREACHABLE;
289     }
290 
291     SkASSERT(loopInfo->fCount >= 0);
292     if (loopInfo->fCount >= kLoopTerminationLimit) {
293         errors.error(loopPos, "loop must guarantee termination in fewer iterations");
294         return nullptr;
295     }
296 
297     return loopInfo;
298 }
299 
300 }  // namespace SkSL
301