1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/ValidateLimitations.h"
8
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/ParseContext.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13
14 namespace sh
15 {
16
17 namespace
18 {
19
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22 // Here we assume all the operations are valid, because the loop node is
23 // already validated before this call.
24 TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
26 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
27
28 return symbol->uniqueId().get();
29 }
30
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42 public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43 ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44 : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45 {}
46
47 // Returns true if the parsed node represents a constant index expression.
isValid() const48 bool isValid() const { return mValid; }
49
visitSymbol(TIntermSymbol * symbol)50 void visitSymbol(TIntermSymbol *symbol) override
51 {
52 // Only constants and loop indices are allowed in a
53 // constant index expression.
54 if (mValid)
55 {
56 bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57 symbol->uniqueId().get()) != mLoopSymbolIds.end();
58 mValid = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59 }
60 }
61
62 private:
63 bool mValid;
64 const std::vector<int> mLoopSymbolIds;
65 };
66
// Traverses the intermediate tree to ensure that the shader does not exceed the
// minimum functionality mandated in GLSL ES 1.0 spec, Appendix A (restrictions on
// loops and on indexing). Violations are reported through the TDiagnostics passed
// to the constructor.
class ValidateLimitationsTraverser : public TLValueTrackingTraverser
{
  public:
    ValidateLimitationsTraverser(sh::GLenum shaderType,
                                 TSymbolTable *symbolTable,
                                 TDiagnostics *diagnostics);

    // Reports an error when an active loop index is used where an l-value is required.
    void visitSymbol(TIntermSymbol *node) override;
    // Checks EOpIndexDirect / EOpIndexIndirect expressions via validateIndexing().
    bool visitBinary(Visit, TIntermBinary *) override;
    // Validates the loop type and for-loop header, then traverses the body with the
    // loop index recorded in mLoopSymbolIds.
    bool visitLoop(Visit, TIntermLoop *) override;

  private:
    // Both overloads forward the error to mDiagnostics.
    void error(TSourceLoc loc, const char *reason, const char *token);
    void error(TSourceLoc loc, const char *reason, const ImmutableString &token);

    // Returns true if |symbol| is the index of one of the loops currently being traversed.
    bool isLoopIndex(TIntermSymbol *symbol);
    // Accepts for loops only; while and do-while loops are reported and rejected.
    bool validateLoopType(TIntermLoop *node);

    // Validates the init-declaration, condition and expression of a for loop.
    bool validateForLoopHeader(TIntermLoop *node);
    // If valid, return the index symbol id; Otherwise, return -1.
    int validateForLoopInit(TIntermLoop *node);
    bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
    bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);

    // Returns true if |node| is a constant union node with the const qualifier.
    bool isConstExpr(TIntermNode *node);
    // Returns true if |node| only references constants and active loop indices.
    bool isConstIndexExpr(TIntermNode *node);
    // Returns true if indexing does not exceed the minimum functionality
    // mandated in GLSL 1.0 spec, Appendix A, Section 5.
    bool validateIndexing(TIntermBinary *node);

    sh::GLenum mShaderType;
    TDiagnostics *mDiagnostics;
    // Unique ids of the indices of the for loops enclosing the node being visited;
    // pushed and popped by visitLoop(), innermost loop last.
    std::vector<int> mLoopSymbolIds;
};
103
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
105 TSymbolTable *symbolTable,
106 TDiagnostics *diagnostics)
107 : TLValueTrackingTraverser(true, false, false, symbolTable),
108 mShaderType(shaderType),
109 mDiagnostics(diagnostics)
110 {
111 ASSERT(diagnostics);
112 }
113
visitSymbol(TIntermSymbol * node)114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
115 {
116 if (isLoopIndex(node) && isLValueRequiredHere())
117 {
118 error(node->getLine(),
119 "Loop index cannot be statically assigned to within the body of the loop",
120 node->getName());
121 }
122 }
123
visitBinary(Visit,TIntermBinary * node)124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
125 {
126 // Check indexing.
127 switch (node->getOp())
128 {
129 case EOpIndexDirect:
130 case EOpIndexIndirect:
131 validateIndexing(node);
132 break;
133 default:
134 break;
135 }
136 return true;
137 }
138
visitLoop(Visit,TIntermLoop * node)139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
140 {
141 if (!validateLoopType(node))
142 return false;
143
144 if (!validateForLoopHeader(node))
145 return false;
146
147 mLoopSymbolIds.push_back(GetLoopSymbolId(node));
148 node->getBody()->traverse(this);
149 mLoopSymbolIds.pop_back();
150
151 // The loop is fully processed - no need to visit children.
152 return false;
153 }
154
error(TSourceLoc loc,const char * reason,const char * token)155 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
156 {
157 mDiagnostics->error(loc, reason, token);
158 }
159
error(TSourceLoc loc,const char * reason,const ImmutableString & token)160 void ValidateLimitationsTraverser::error(TSourceLoc loc,
161 const char *reason,
162 const ImmutableString &token)
163 {
164 error(loc, reason, token.data());
165 }
166
isLoopIndex(TIntermSymbol * symbol)167 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
168 {
169 return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
170 mLoopSymbolIds.end();
171 }
172
validateLoopType(TIntermLoop * node)173 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
174 {
175 TLoopType type = node->getType();
176 if (type == ELoopFor)
177 return true;
178
179 // Reject while and do-while loops.
180 error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
181 return false;
182 }
183
validateForLoopHeader(TIntermLoop * node)184 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
185 {
186 ASSERT(node->getType() == ELoopFor);
187
188 //
189 // The for statement has the form:
190 // for ( init-declaration ; condition ; expression ) statement
191 //
192 int indexSymbolId = validateForLoopInit(node);
193 if (indexSymbolId < 0)
194 return false;
195 if (!validateForLoopCond(node, indexSymbolId))
196 return false;
197 if (!validateForLoopExpr(node, indexSymbolId))
198 return false;
199
200 return true;
201 }
202
validateForLoopInit(TIntermLoop * node)203 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
204 {
205 TIntermNode *init = node->getInit();
206 if (init == nullptr)
207 {
208 error(node->getLine(), "Missing init declaration", "for");
209 return -1;
210 }
211
212 //
213 // init-declaration has the form:
214 // type-specifier identifier = constant-expression
215 //
216 TIntermDeclaration *decl = init->getAsDeclarationNode();
217 if (decl == nullptr)
218 {
219 error(init->getLine(), "Invalid init declaration", "for");
220 return -1;
221 }
222 // To keep things simple do not allow declaration list.
223 TIntermSequence *declSeq = decl->getSequence();
224 if (declSeq->size() != 1)
225 {
226 error(decl->getLine(), "Invalid init declaration", "for");
227 return -1;
228 }
229 TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
230 if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
231 {
232 error(decl->getLine(), "Invalid init declaration", "for");
233 return -1;
234 }
235 TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
236 if (symbol == nullptr)
237 {
238 error(declInit->getLine(), "Invalid init declaration", "for");
239 return -1;
240 }
241 // The loop index has type int or float.
242 TBasicType type = symbol->getBasicType();
243 if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
244 {
245 error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
246 return -1;
247 }
248 // The loop index is initialized with constant expression.
249 if (!isConstExpr(declInit->getRight()))
250 {
251 error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
252 symbol->getName());
253 return -1;
254 }
255
256 return symbol->uniqueId().get();
257 }
258
validateForLoopCond(TIntermLoop * node,int indexSymbolId)259 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
260 {
261 TIntermNode *cond = node->getCondition();
262 if (cond == nullptr)
263 {
264 error(node->getLine(), "Missing condition", "for");
265 return false;
266 }
267 //
268 // condition has the form:
269 // loop_index relational_operator constant_expression
270 //
271 TIntermBinary *binOp = cond->getAsBinaryNode();
272 if (binOp == nullptr)
273 {
274 error(node->getLine(), "Invalid condition", "for");
275 return false;
276 }
277 // Loop index should be to the left of relational operator.
278 TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
279 if (symbol == nullptr)
280 {
281 error(binOp->getLine(), "Invalid condition", "for");
282 return false;
283 }
284 if (symbol->uniqueId().get() != indexSymbolId)
285 {
286 error(symbol->getLine(), "Expected loop index", symbol->getName());
287 return false;
288 }
289 // Relational operator is one of: > >= < <= == or !=.
290 switch (binOp->getOp())
291 {
292 case EOpEqual:
293 case EOpNotEqual:
294 case EOpLessThan:
295 case EOpGreaterThan:
296 case EOpLessThanEqual:
297 case EOpGreaterThanEqual:
298 break;
299 default:
300 error(binOp->getLine(), "Invalid relational operator",
301 GetOperatorString(binOp->getOp()));
302 break;
303 }
304 // Loop index must be compared with a constant.
305 if (!isConstExpr(binOp->getRight()))
306 {
307 error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
308 symbol->getName());
309 return false;
310 }
311
312 return true;
313 }
314
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)315 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
316 {
317 TIntermNode *expr = node->getExpression();
318 if (expr == nullptr)
319 {
320 error(node->getLine(), "Missing expression", "for");
321 return false;
322 }
323
324 // for expression has one of the following forms:
325 // loop_index++
326 // loop_index--
327 // loop_index += constant_expression
328 // loop_index -= constant_expression
329 // ++loop_index
330 // --loop_index
331 // The last two forms are not specified in the spec, but I am assuming
332 // its an oversight.
333 TIntermUnary *unOp = expr->getAsUnaryNode();
334 TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
335
336 TOperator op = EOpNull;
337 const TFunction *opFunc = nullptr;
338 TIntermSymbol *symbol = nullptr;
339 if (unOp != nullptr)
340 {
341 op = unOp->getOp();
342 opFunc = unOp->getFunction();
343 symbol = unOp->getOperand()->getAsSymbolNode();
344 }
345 else if (binOp != nullptr)
346 {
347 op = binOp->getOp();
348 symbol = binOp->getLeft()->getAsSymbolNode();
349 }
350
351 // The operand must be loop index.
352 if (symbol == nullptr)
353 {
354 error(expr->getLine(), "Invalid expression", "for");
355 return false;
356 }
357 if (symbol->uniqueId().get() != indexSymbolId)
358 {
359 error(symbol->getLine(), "Expected loop index", symbol->getName());
360 return false;
361 }
362
363 // The operator is one of: ++ -- += -=.
364 switch (op)
365 {
366 case EOpPostIncrement:
367 case EOpPostDecrement:
368 case EOpPreIncrement:
369 case EOpPreDecrement:
370 ASSERT((unOp != nullptr) && (binOp == nullptr));
371 break;
372 case EOpAddAssign:
373 case EOpSubAssign:
374 ASSERT((unOp == nullptr) && (binOp != nullptr));
375 break;
376 default:
377 if (BuiltInGroup::IsBuiltIn(op))
378 {
379 ASSERT(opFunc != nullptr);
380 error(expr->getLine(), "Invalid built-in call", opFunc->name().data());
381 }
382 else
383 {
384 error(expr->getLine(), "Invalid operator", GetOperatorString(op));
385 }
386 return false;
387 }
388
389 // Loop index must be incremented/decremented with a constant.
390 if (binOp != nullptr)
391 {
392 if (!isConstExpr(binOp->getRight()))
393 {
394 error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
395 symbol->getName());
396 return false;
397 }
398 }
399
400 return true;
401 }
402
isConstExpr(TIntermNode * node)403 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
404 {
405 ASSERT(node != nullptr);
406 return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
407 }
408
isConstIndexExpr(TIntermNode * node)409 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
410 {
411 ASSERT(node != nullptr);
412
413 ValidateConstIndexExpr validate(mLoopSymbolIds);
414 node->traverse(&validate);
415 return validate.isValid();
416 }
417
validateIndexing(TIntermBinary * node)418 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
419 {
420 ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
421
422 bool valid = true;
423 TIntermTyped *index = node->getRight();
424 // The index expession must be a constant-index-expression unless
425 // the operand is a uniform in a vertex shader.
426 TIntermTyped *operand = node->getLeft();
427 bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
428 if (!skip && !isConstIndexExpr(index))
429 {
430 error(index->getLine(), "Index expression must be constant", "[]");
431 valid = false;
432 }
433 return valid;
434 }
435
436 } // namespace
437
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)438 bool ValidateLimitations(TIntermNode *root,
439 GLenum shaderType,
440 TSymbolTable *symbolTable,
441 TDiagnostics *diagnostics)
442 {
443 ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
444 root->traverse(&validate);
445 return diagnostics->numErrors() == 0;
446 }
447
448 } // namespace sh
449