xref: /aosp_15_r20/external/angle/src/compiler/translator/ValidateLimitations.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/ValidateLimitations.h"
8 
9 #include "angle_gl.h"
10 #include "compiler/translator/Diagnostics.h"
11 #include "compiler/translator/ParseContext.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13 
14 namespace sh
15 {
16 
17 namespace
18 {
19 
GetLoopSymbolId(TIntermLoop * loop)20 int GetLoopSymbolId(TIntermLoop *loop)
21 {
22     // Here we assume all the operations are valid, because the loop node is
23     // already validated before this call.
24     TIntermSequence *declSeq = loop->getInit()->getAsDeclarationNode()->getSequence();
25     TIntermBinary *declInit  = (*declSeq)[0]->getAsBinaryNode();
26     TIntermSymbol *symbol    = declInit->getLeft()->getAsSymbolNode();
27 
28     return symbol->uniqueId().get();
29 }
30 
31 // Traverses a node to check if it represents a constant index expression.
32 // Definition:
33 // constant-index-expressions are a superset of constant-expressions.
34 // Constant-index-expressions can include loop indices as defined in
35 // GLSL ES 1.0 spec, Appendix A, section 4.
36 // The following are constant-index-expressions:
37 // - Constant expressions
38 // - Loop indices as defined in section 4
39 // - Expressions composed of both of the above
40 class ValidateConstIndexExpr : public TIntermTraverser
41 {
42   public:
ValidateConstIndexExpr(const std::vector<int> & loopSymbols)43     ValidateConstIndexExpr(const std::vector<int> &loopSymbols)
44         : TIntermTraverser(true, false, false), mValid(true), mLoopSymbolIds(loopSymbols)
45     {}
46 
47     // Returns true if the parsed node represents a constant index expression.
isValid() const48     bool isValid() const { return mValid; }
49 
visitSymbol(TIntermSymbol * symbol)50     void visitSymbol(TIntermSymbol *symbol) override
51     {
52         // Only constants and loop indices are allowed in a
53         // constant index expression.
54         if (mValid)
55         {
56             bool isLoopSymbol = std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(),
57                                           symbol->uniqueId().get()) != mLoopSymbolIds.end();
58             mValid            = (symbol->getQualifier() == EvqConst) || isLoopSymbol;
59         }
60     }
61 
62   private:
63     bool mValid;
64     const std::vector<int> mLoopSymbolIds;
65 };
66 
67 // Traverses intermediate tree to ensure that the shader does not exceed the
68 // minimum functionality mandated in GLSL 1.0 spec, Appendix A.
69 class ValidateLimitationsTraverser : public TLValueTrackingTraverser
70 {
71   public:
72     ValidateLimitationsTraverser(sh::GLenum shaderType,
73                                  TSymbolTable *symbolTable,
74                                  TDiagnostics *diagnostics);
75 
76     void visitSymbol(TIntermSymbol *node) override;
77     bool visitBinary(Visit, TIntermBinary *) override;
78     bool visitLoop(Visit, TIntermLoop *) override;
79 
80   private:
81     void error(TSourceLoc loc, const char *reason, const char *token);
82     void error(TSourceLoc loc, const char *reason, const ImmutableString &token);
83 
84     bool isLoopIndex(TIntermSymbol *symbol);
85     bool validateLoopType(TIntermLoop *node);
86 
87     bool validateForLoopHeader(TIntermLoop *node);
88     // If valid, return the index symbol id; Otherwise, return -1.
89     int validateForLoopInit(TIntermLoop *node);
90     bool validateForLoopCond(TIntermLoop *node, int indexSymbolId);
91     bool validateForLoopExpr(TIntermLoop *node, int indexSymbolId);
92 
93     // Returns true if indexing does not exceed the minimum functionality
94     // mandated in GLSL 1.0 spec, Appendix A, Section 5.
95     bool isConstExpr(TIntermNode *node);
96     bool isConstIndexExpr(TIntermNode *node);
97     bool validateIndexing(TIntermBinary *node);
98 
99     sh::GLenum mShaderType;
100     TDiagnostics *mDiagnostics;
101     std::vector<int> mLoopSymbolIds;
102 };
103 
ValidateLimitationsTraverser(sh::GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)104 ValidateLimitationsTraverser::ValidateLimitationsTraverser(sh::GLenum shaderType,
105                                                            TSymbolTable *symbolTable,
106                                                            TDiagnostics *diagnostics)
107     : TLValueTrackingTraverser(true, false, false, symbolTable),
108       mShaderType(shaderType),
109       mDiagnostics(diagnostics)
110 {
111     ASSERT(diagnostics);
112 }
113 
visitSymbol(TIntermSymbol * node)114 void ValidateLimitationsTraverser::visitSymbol(TIntermSymbol *node)
115 {
116     if (isLoopIndex(node) && isLValueRequiredHere())
117     {
118         error(node->getLine(),
119               "Loop index cannot be statically assigned to within the body of the loop",
120               node->getName());
121     }
122 }
123 
visitBinary(Visit,TIntermBinary * node)124 bool ValidateLimitationsTraverser::visitBinary(Visit, TIntermBinary *node)
125 {
126     // Check indexing.
127     switch (node->getOp())
128     {
129         case EOpIndexDirect:
130         case EOpIndexIndirect:
131             validateIndexing(node);
132             break;
133         default:
134             break;
135     }
136     return true;
137 }
138 
visitLoop(Visit,TIntermLoop * node)139 bool ValidateLimitationsTraverser::visitLoop(Visit, TIntermLoop *node)
140 {
141     if (!validateLoopType(node))
142         return false;
143 
144     if (!validateForLoopHeader(node))
145         return false;
146 
147     mLoopSymbolIds.push_back(GetLoopSymbolId(node));
148     node->getBody()->traverse(this);
149     mLoopSymbolIds.pop_back();
150 
151     // The loop is fully processed - no need to visit children.
152     return false;
153 }
154 
error(TSourceLoc loc,const char * reason,const char * token)155 void ValidateLimitationsTraverser::error(TSourceLoc loc, const char *reason, const char *token)
156 {
157     mDiagnostics->error(loc, reason, token);
158 }
159 
error(TSourceLoc loc,const char * reason,const ImmutableString & token)160 void ValidateLimitationsTraverser::error(TSourceLoc loc,
161                                          const char *reason,
162                                          const ImmutableString &token)
163 {
164     error(loc, reason, token.data());
165 }
166 
isLoopIndex(TIntermSymbol * symbol)167 bool ValidateLimitationsTraverser::isLoopIndex(TIntermSymbol *symbol)
168 {
169     return std::find(mLoopSymbolIds.begin(), mLoopSymbolIds.end(), symbol->uniqueId().get()) !=
170            mLoopSymbolIds.end();
171 }
172 
validateLoopType(TIntermLoop * node)173 bool ValidateLimitationsTraverser::validateLoopType(TIntermLoop *node)
174 {
175     TLoopType type = node->getType();
176     if (type == ELoopFor)
177         return true;
178 
179     // Reject while and do-while loops.
180     error(node->getLine(), "This type of loop is not allowed", type == ELoopWhile ? "while" : "do");
181     return false;
182 }
183 
validateForLoopHeader(TIntermLoop * node)184 bool ValidateLimitationsTraverser::validateForLoopHeader(TIntermLoop *node)
185 {
186     ASSERT(node->getType() == ELoopFor);
187 
188     //
189     // The for statement has the form:
190     //    for ( init-declaration ; condition ; expression ) statement
191     //
192     int indexSymbolId = validateForLoopInit(node);
193     if (indexSymbolId < 0)
194         return false;
195     if (!validateForLoopCond(node, indexSymbolId))
196         return false;
197     if (!validateForLoopExpr(node, indexSymbolId))
198         return false;
199 
200     return true;
201 }
202 
validateForLoopInit(TIntermLoop * node)203 int ValidateLimitationsTraverser::validateForLoopInit(TIntermLoop *node)
204 {
205     TIntermNode *init = node->getInit();
206     if (init == nullptr)
207     {
208         error(node->getLine(), "Missing init declaration", "for");
209         return -1;
210     }
211 
212     //
213     // init-declaration has the form:
214     //     type-specifier identifier = constant-expression
215     //
216     TIntermDeclaration *decl = init->getAsDeclarationNode();
217     if (decl == nullptr)
218     {
219         error(init->getLine(), "Invalid init declaration", "for");
220         return -1;
221     }
222     // To keep things simple do not allow declaration list.
223     TIntermSequence *declSeq = decl->getSequence();
224     if (declSeq->size() != 1)
225     {
226         error(decl->getLine(), "Invalid init declaration", "for");
227         return -1;
228     }
229     TIntermBinary *declInit = (*declSeq)[0]->getAsBinaryNode();
230     if ((declInit == nullptr) || (declInit->getOp() != EOpInitialize))
231     {
232         error(decl->getLine(), "Invalid init declaration", "for");
233         return -1;
234     }
235     TIntermSymbol *symbol = declInit->getLeft()->getAsSymbolNode();
236     if (symbol == nullptr)
237     {
238         error(declInit->getLine(), "Invalid init declaration", "for");
239         return -1;
240     }
241     // The loop index has type int or float.
242     TBasicType type = symbol->getBasicType();
243     if ((type != EbtInt) && (type != EbtUInt) && (type != EbtFloat))
244     {
245         error(symbol->getLine(), "Invalid type for loop index", getBasicString(type));
246         return -1;
247     }
248     // The loop index is initialized with constant expression.
249     if (!isConstExpr(declInit->getRight()))
250     {
251         error(declInit->getLine(), "Loop index cannot be initialized with non-constant expression",
252               symbol->getName());
253         return -1;
254     }
255 
256     return symbol->uniqueId().get();
257 }
258 
validateForLoopCond(TIntermLoop * node,int indexSymbolId)259 bool ValidateLimitationsTraverser::validateForLoopCond(TIntermLoop *node, int indexSymbolId)
260 {
261     TIntermNode *cond = node->getCondition();
262     if (cond == nullptr)
263     {
264         error(node->getLine(), "Missing condition", "for");
265         return false;
266     }
267     //
268     // condition has the form:
269     //     loop_index relational_operator constant_expression
270     //
271     TIntermBinary *binOp = cond->getAsBinaryNode();
272     if (binOp == nullptr)
273     {
274         error(node->getLine(), "Invalid condition", "for");
275         return false;
276     }
277     // Loop index should be to the left of relational operator.
278     TIntermSymbol *symbol = binOp->getLeft()->getAsSymbolNode();
279     if (symbol == nullptr)
280     {
281         error(binOp->getLine(), "Invalid condition", "for");
282         return false;
283     }
284     if (symbol->uniqueId().get() != indexSymbolId)
285     {
286         error(symbol->getLine(), "Expected loop index", symbol->getName());
287         return false;
288     }
289     // Relational operator is one of: > >= < <= == or !=.
290     switch (binOp->getOp())
291     {
292         case EOpEqual:
293         case EOpNotEqual:
294         case EOpLessThan:
295         case EOpGreaterThan:
296         case EOpLessThanEqual:
297         case EOpGreaterThanEqual:
298             break;
299         default:
300             error(binOp->getLine(), "Invalid relational operator",
301                   GetOperatorString(binOp->getOp()));
302             break;
303     }
304     // Loop index must be compared with a constant.
305     if (!isConstExpr(binOp->getRight()))
306     {
307         error(binOp->getLine(), "Loop index cannot be compared with non-constant expression",
308               symbol->getName());
309         return false;
310     }
311 
312     return true;
313 }
314 
validateForLoopExpr(TIntermLoop * node,int indexSymbolId)315 bool ValidateLimitationsTraverser::validateForLoopExpr(TIntermLoop *node, int indexSymbolId)
316 {
317     TIntermNode *expr = node->getExpression();
318     if (expr == nullptr)
319     {
320         error(node->getLine(), "Missing expression", "for");
321         return false;
322     }
323 
324     // for expression has one of the following forms:
325     //     loop_index++
326     //     loop_index--
327     //     loop_index += constant_expression
328     //     loop_index -= constant_expression
329     //     ++loop_index
330     //     --loop_index
331     // The last two forms are not specified in the spec, but I am assuming
332     // its an oversight.
333     TIntermUnary *unOp   = expr->getAsUnaryNode();
334     TIntermBinary *binOp = unOp ? nullptr : expr->getAsBinaryNode();
335 
336     TOperator op            = EOpNull;
337     const TFunction *opFunc = nullptr;
338     TIntermSymbol *symbol   = nullptr;
339     if (unOp != nullptr)
340     {
341         op     = unOp->getOp();
342         opFunc = unOp->getFunction();
343         symbol = unOp->getOperand()->getAsSymbolNode();
344     }
345     else if (binOp != nullptr)
346     {
347         op     = binOp->getOp();
348         symbol = binOp->getLeft()->getAsSymbolNode();
349     }
350 
351     // The operand must be loop index.
352     if (symbol == nullptr)
353     {
354         error(expr->getLine(), "Invalid expression", "for");
355         return false;
356     }
357     if (symbol->uniqueId().get() != indexSymbolId)
358     {
359         error(symbol->getLine(), "Expected loop index", symbol->getName());
360         return false;
361     }
362 
363     // The operator is one of: ++ -- += -=.
364     switch (op)
365     {
366         case EOpPostIncrement:
367         case EOpPostDecrement:
368         case EOpPreIncrement:
369         case EOpPreDecrement:
370             ASSERT((unOp != nullptr) && (binOp == nullptr));
371             break;
372         case EOpAddAssign:
373         case EOpSubAssign:
374             ASSERT((unOp == nullptr) && (binOp != nullptr));
375             break;
376         default:
377             if (BuiltInGroup::IsBuiltIn(op))
378             {
379                 ASSERT(opFunc != nullptr);
380                 error(expr->getLine(), "Invalid built-in call", opFunc->name().data());
381             }
382             else
383             {
384                 error(expr->getLine(), "Invalid operator", GetOperatorString(op));
385             }
386             return false;
387     }
388 
389     // Loop index must be incremented/decremented with a constant.
390     if (binOp != nullptr)
391     {
392         if (!isConstExpr(binOp->getRight()))
393         {
394             error(binOp->getLine(), "Loop index cannot be modified by non-constant expression",
395                   symbol->getName());
396             return false;
397         }
398     }
399 
400     return true;
401 }
402 
isConstExpr(TIntermNode * node)403 bool ValidateLimitationsTraverser::isConstExpr(TIntermNode *node)
404 {
405     ASSERT(node != nullptr);
406     return node->getAsConstantUnion() != nullptr && node->getAsTyped()->getQualifier() == EvqConst;
407 }
408 
isConstIndexExpr(TIntermNode * node)409 bool ValidateLimitationsTraverser::isConstIndexExpr(TIntermNode *node)
410 {
411     ASSERT(node != nullptr);
412 
413     ValidateConstIndexExpr validate(mLoopSymbolIds);
414     node->traverse(&validate);
415     return validate.isValid();
416 }
417 
validateIndexing(TIntermBinary * node)418 bool ValidateLimitationsTraverser::validateIndexing(TIntermBinary *node)
419 {
420     ASSERT((node->getOp() == EOpIndexDirect) || (node->getOp() == EOpIndexIndirect));
421 
422     bool valid          = true;
423     TIntermTyped *index = node->getRight();
424     // The index expession must be a constant-index-expression unless
425     // the operand is a uniform in a vertex shader.
426     TIntermTyped *operand = node->getLeft();
427     bool skip = (mShaderType == GL_VERTEX_SHADER) && (operand->getQualifier() == EvqUniform);
428     if (!skip && !isConstIndexExpr(index))
429     {
430         error(index->getLine(), "Index expression must be constant", "[]");
431         valid = false;
432     }
433     return valid;
434 }
435 
436 }  // namespace
437 
ValidateLimitations(TIntermNode * root,GLenum shaderType,TSymbolTable * symbolTable,TDiagnostics * diagnostics)438 bool ValidateLimitations(TIntermNode *root,
439                          GLenum shaderType,
440                          TSymbolTable *symbolTable,
441                          TDiagnostics *diagnostics)
442 {
443     ValidateLimitationsTraverser validate(shaderType, symbolTable, diagnostics);
444     root->traverse(&validate);
445     return diagnostics->numErrors() == 0;
446 }
447 
448 }  // namespace sh
449