xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/EmulateMultiDrawShaderBuiltins.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// EmulateGLDrawID uses an AST traverser to find the gl_DrawID builtin and
// converts it to a uniform int
//
// EmulateGLBaseVertexBaseInstance uses AST traversers to find the gl_BaseVertex
// and gl_BaseInstance builtins and converts each of them to a uniform int.
// On request it first rewrites gl_VertexID as (gl_VertexID + gl_BaseVertex).
//
14 //
15 
16 #include "compiler/translator/tree_ops/EmulateMultiDrawShaderBuiltins.h"
17 
18 #include "angle_gl.h"
19 #include "compiler/translator/StaticType.h"
20 #include "compiler/translator/Symbol.h"
21 #include "compiler/translator/SymbolTable.h"
22 #include "compiler/translator/tree_util/BuiltIn.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/ReplaceVariable.h"
25 #include "compiler/translator/util.h"
26 
27 namespace sh
28 {
29 
30 namespace
31 {
32 
33 constexpr const ImmutableString kEmulatedGLDrawIDName("angle_DrawID");
34 
35 class FindGLDrawIDTraverser : public TIntermTraverser
36 {
37   public:
FindGLDrawIDTraverser()38     FindGLDrawIDTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
39 
getGLDrawIDBuiltinVariable()40     const TVariable *getGLDrawIDBuiltinVariable() { return mVariable; }
41 
42   protected:
visitSymbol(TIntermSymbol * node)43     void visitSymbol(TIntermSymbol *node) override
44     {
45         if (&node->variable() == BuiltInVariable::gl_DrawID())
46         {
47             mVariable = &node->variable();
48         }
49     }
50 
51   private:
52     const TVariable *mVariable;
53 };
54 
55 class AddBaseVertexToGLVertexIDTraverser : public TIntermTraverser
56 {
57   public:
AddBaseVertexToGLVertexIDTraverser()58     AddBaseVertexToGLVertexIDTraverser() : TIntermTraverser(true, false, false) {}
59 
60   protected:
visitSymbol(TIntermSymbol * node)61     void visitSymbol(TIntermSymbol *node) override
62     {
63         if (&node->variable() == BuiltInVariable::gl_VertexID())
64         {
65 
66             TIntermSymbol *baseVertexRef = new TIntermSymbol(BuiltInVariable::gl_BaseVertex());
67 
68             TIntermBinary *addBaseVertex = new TIntermBinary(EOpAdd, node, baseVertexRef);
69             queueReplacement(addBaseVertex, OriginalNode::BECOMES_CHILD);
70         }
71     }
72 };
73 
74 constexpr const ImmutableString kEmulatedGLBaseVertexName("angle_BaseVertex");
75 
76 class FindGLBaseVertexTraverser : public TIntermTraverser
77 {
78   public:
FindGLBaseVertexTraverser()79     FindGLBaseVertexTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
80 
getGLBaseVertexBuiltinVariable()81     const TVariable *getGLBaseVertexBuiltinVariable() { return mVariable; }
82 
83   protected:
visitSymbol(TIntermSymbol * node)84     void visitSymbol(TIntermSymbol *node) override
85     {
86         if (&node->variable() == BuiltInVariable::gl_BaseVertex())
87         {
88             mVariable = &node->variable();
89         }
90     }
91 
92   private:
93     const TVariable *mVariable;
94 };
95 
96 constexpr const ImmutableString kEmulatedGLBaseInstanceName("angle_BaseInstance");
97 
98 class FindGLBaseInstanceTraverser : public TIntermTraverser
99 {
100   public:
FindGLBaseInstanceTraverser()101     FindGLBaseInstanceTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
102 
getGLBaseInstanceBuiltinVariable()103     const TVariable *getGLBaseInstanceBuiltinVariable() { return mVariable; }
104 
105   protected:
visitSymbol(TIntermSymbol * node)106     void visitSymbol(TIntermSymbol *node) override
107     {
108         if (&node->variable() == BuiltInVariable::gl_BaseInstance())
109         {
110             mVariable = &node->variable();
111         }
112     }
113 
114   private:
115     const TVariable *mVariable;
116 };
117 
118 }  // namespace
119 
EmulateGLDrawID(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms)120 bool EmulateGLDrawID(TCompiler *compiler,
121                      TIntermBlock *root,
122                      TSymbolTable *symbolTable,
123                      std::vector<sh::ShaderVariable> *uniforms)
124 {
125     FindGLDrawIDTraverser traverser;
126     root->traverse(&traverser);
127     const TVariable *builtInVariable = traverser.getGLDrawIDBuiltinVariable();
128     if (builtInVariable)
129     {
130         const TType *type = StaticType::Get<EbtInt, EbpHigh, EvqUniform, 1, 1>();
131         const TVariable *drawID =
132             new TVariable(symbolTable, kEmulatedGLDrawIDName, type, SymbolType::AngleInternal);
133         const TIntermSymbol *drawIDSymbol = new TIntermSymbol(drawID);
134 
135         // AngleInternal variables don't get collected
136         ShaderVariable uniform;
137         uniform.name          = kEmulatedGLDrawIDName.data();
138         uniform.mappedName    = kEmulatedGLDrawIDName.data();
139         uniform.type          = GLVariableType(*type);
140         uniform.precision     = GLVariablePrecision(*type);
141         uniform.staticUse     = symbolTable->isStaticallyUsed(*builtInVariable);
142         uniform.active        = true;
143         uniform.binding       = type->getLayoutQualifier().binding;
144         uniform.location      = type->getLayoutQualifier().location;
145         uniform.offset        = type->getLayoutQualifier().offset;
146         uniform.rasterOrdered = type->getLayoutQualifier().rasterOrdered;
147         uniform.readonly      = type->getMemoryQualifier().readonly;
148         uniform.writeonly     = type->getMemoryQualifier().writeonly;
149         uniforms->push_back(uniform);
150 
151         DeclareGlobalVariable(root, drawID);
152         if (!ReplaceVariableWithTyped(compiler, root, builtInVariable, drawIDSymbol))
153         {
154             return false;
155         }
156     }
157 
158     return true;
159 }
160 
EmulateGLBaseVertexBaseInstance(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms,bool addBaseVertexToVertexID)161 bool EmulateGLBaseVertexBaseInstance(TCompiler *compiler,
162                                      TIntermBlock *root,
163                                      TSymbolTable *symbolTable,
164                                      std::vector<sh::ShaderVariable> *uniforms,
165                                      bool addBaseVertexToVertexID)
166 {
167     bool addBaseVertex = false, addBaseInstance = false;
168     ShaderVariable uniformBaseVertex, uniformBaseInstance;
169 
170     if (addBaseVertexToVertexID)
171     {
172         // This is a workaround for Mac AMD GPU
173         // Replace gl_VertexID with (gl_VertexID + gl_BaseVertex)
174         AddBaseVertexToGLVertexIDTraverser traverserVertexID;
175         root->traverse(&traverserVertexID);
176         if (!traverserVertexID.updateTree(compiler, root))
177         {
178             return false;
179         }
180     }
181 
182     FindGLBaseVertexTraverser traverserBaseVertex;
183     root->traverse(&traverserBaseVertex);
184     const TVariable *builtInVariableBaseVertex =
185         traverserBaseVertex.getGLBaseVertexBuiltinVariable();
186 
187     if (builtInVariableBaseVertex)
188     {
189         const TVariable *baseVertex           = BuiltInVariable::angle_BaseVertex();
190         const TType &type                     = baseVertex->getType();
191         const TIntermSymbol *baseVertexSymbol = new TIntermSymbol(baseVertex);
192 
193         // AngleInternal variables don't get collected
194         uniformBaseVertex.name          = kEmulatedGLBaseVertexName.data();
195         uniformBaseVertex.mappedName    = kEmulatedGLBaseVertexName.data();
196         uniformBaseVertex.type          = GLVariableType(type);
197         uniformBaseVertex.precision     = GLVariablePrecision(type);
198         uniformBaseVertex.staticUse     = symbolTable->isStaticallyUsed(*builtInVariableBaseVertex);
199         uniformBaseVertex.active        = true;
200         uniformBaseVertex.binding       = type.getLayoutQualifier().binding;
201         uniformBaseVertex.location      = type.getLayoutQualifier().location;
202         uniformBaseVertex.offset        = type.getLayoutQualifier().offset;
203         uniformBaseVertex.rasterOrdered = type.getLayoutQualifier().rasterOrdered;
204         uniformBaseVertex.readonly      = type.getMemoryQualifier().readonly;
205         uniformBaseVertex.writeonly     = type.getMemoryQualifier().writeonly;
206         addBaseVertex                   = true;
207 
208         DeclareGlobalVariable(root, baseVertex);
209         if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseVertex, baseVertexSymbol))
210         {
211             return false;
212         }
213     }
214 
215     FindGLBaseInstanceTraverser traverserInstance;
216     root->traverse(&traverserInstance);
217     const TVariable *builtInVariableBaseInstance =
218         traverserInstance.getGLBaseInstanceBuiltinVariable();
219 
220     if (builtInVariableBaseInstance)
221     {
222         const TVariable *baseInstance           = BuiltInVariable::angle_BaseInstance();
223         const TType &type                       = baseInstance->getType();
224         const TIntermSymbol *baseInstanceSymbol = new TIntermSymbol(baseInstance);
225 
226         // AngleInternal variables don't get collected
227         uniformBaseInstance.name       = kEmulatedGLBaseInstanceName.data();
228         uniformBaseInstance.mappedName = kEmulatedGLBaseInstanceName.data();
229         uniformBaseInstance.type       = GLVariableType(type);
230         uniformBaseInstance.precision  = GLVariablePrecision(type);
231         uniformBaseInstance.staticUse = symbolTable->isStaticallyUsed(*builtInVariableBaseInstance);
232         uniformBaseInstance.active    = true;
233         uniformBaseInstance.binding   = type.getLayoutQualifier().binding;
234         uniformBaseInstance.location  = type.getLayoutQualifier().location;
235         uniformBaseInstance.offset    = type.getLayoutQualifier().offset;
236         uniformBaseInstance.rasterOrdered = type.getLayoutQualifier().rasterOrdered;
237         uniformBaseInstance.readonly      = type.getMemoryQualifier().readonly;
238         uniformBaseInstance.writeonly     = type.getMemoryQualifier().writeonly;
239         addBaseInstance                   = true;
240 
241         DeclareGlobalVariable(root, baseInstance);
242         if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseInstance,
243                                       baseInstanceSymbol))
244         {
245             return false;
246         }
247     }
248 
249     // Make sure the order in uniforms is the same as the traverse order
250     if (addBaseInstance)
251     {
252         uniforms->push_back(uniformBaseInstance);
253     }
254     if (addBaseVertex)
255     {
256         uniforms->push_back(uniformBaseVertex);
257     }
258 
259     return true;
260 }
261 
262 }  // namespace sh
263