1 //
2 // Copyright 2019 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// EmulateGLDrawID is an AST transformation that converts the gl_DrawID builtin
// to a uniform int.
//
// EmulateGLBaseVertexBaseInstance is an AST transformation that converts the
// gl_BaseVertex and gl_BaseInstance builtins to uniform ints. It can optionally
// also rewrite gl_VertexID as (gl_VertexID + gl_BaseVertex).
14 //
15
16 #include "compiler/translator/tree_ops/EmulateMultiDrawShaderBuiltins.h"
17
18 #include "angle_gl.h"
19 #include "compiler/translator/StaticType.h"
20 #include "compiler/translator/Symbol.h"
21 #include "compiler/translator/SymbolTable.h"
22 #include "compiler/translator/tree_util/BuiltIn.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/ReplaceVariable.h"
25 #include "compiler/translator/util.h"
26
27 namespace sh
28 {
29
30 namespace
31 {
32
33 constexpr const ImmutableString kEmulatedGLDrawIDName("angle_DrawID");
34
35 class FindGLDrawIDTraverser : public TIntermTraverser
36 {
37 public:
FindGLDrawIDTraverser()38 FindGLDrawIDTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
39
getGLDrawIDBuiltinVariable()40 const TVariable *getGLDrawIDBuiltinVariable() { return mVariable; }
41
42 protected:
visitSymbol(TIntermSymbol * node)43 void visitSymbol(TIntermSymbol *node) override
44 {
45 if (&node->variable() == BuiltInVariable::gl_DrawID())
46 {
47 mVariable = &node->variable();
48 }
49 }
50
51 private:
52 const TVariable *mVariable;
53 };
54
55 class AddBaseVertexToGLVertexIDTraverser : public TIntermTraverser
56 {
57 public:
AddBaseVertexToGLVertexIDTraverser()58 AddBaseVertexToGLVertexIDTraverser() : TIntermTraverser(true, false, false) {}
59
60 protected:
visitSymbol(TIntermSymbol * node)61 void visitSymbol(TIntermSymbol *node) override
62 {
63 if (&node->variable() == BuiltInVariable::gl_VertexID())
64 {
65
66 TIntermSymbol *baseVertexRef = new TIntermSymbol(BuiltInVariable::gl_BaseVertex());
67
68 TIntermBinary *addBaseVertex = new TIntermBinary(EOpAdd, node, baseVertexRef);
69 queueReplacement(addBaseVertex, OriginalNode::BECOMES_CHILD);
70 }
71 }
72 };
73
74 constexpr const ImmutableString kEmulatedGLBaseVertexName("angle_BaseVertex");
75
76 class FindGLBaseVertexTraverser : public TIntermTraverser
77 {
78 public:
FindGLBaseVertexTraverser()79 FindGLBaseVertexTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
80
getGLBaseVertexBuiltinVariable()81 const TVariable *getGLBaseVertexBuiltinVariable() { return mVariable; }
82
83 protected:
visitSymbol(TIntermSymbol * node)84 void visitSymbol(TIntermSymbol *node) override
85 {
86 if (&node->variable() == BuiltInVariable::gl_BaseVertex())
87 {
88 mVariable = &node->variable();
89 }
90 }
91
92 private:
93 const TVariable *mVariable;
94 };
95
96 constexpr const ImmutableString kEmulatedGLBaseInstanceName("angle_BaseInstance");
97
98 class FindGLBaseInstanceTraverser : public TIntermTraverser
99 {
100 public:
FindGLBaseInstanceTraverser()101 FindGLBaseInstanceTraverser() : TIntermTraverser(true, false, false), mVariable(nullptr) {}
102
getGLBaseInstanceBuiltinVariable()103 const TVariable *getGLBaseInstanceBuiltinVariable() { return mVariable; }
104
105 protected:
visitSymbol(TIntermSymbol * node)106 void visitSymbol(TIntermSymbol *node) override
107 {
108 if (&node->variable() == BuiltInVariable::gl_BaseInstance())
109 {
110 mVariable = &node->variable();
111 }
112 }
113
114 private:
115 const TVariable *mVariable;
116 };
117
118 } // namespace
119
EmulateGLDrawID(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms)120 bool EmulateGLDrawID(TCompiler *compiler,
121 TIntermBlock *root,
122 TSymbolTable *symbolTable,
123 std::vector<sh::ShaderVariable> *uniforms)
124 {
125 FindGLDrawIDTraverser traverser;
126 root->traverse(&traverser);
127 const TVariable *builtInVariable = traverser.getGLDrawIDBuiltinVariable();
128 if (builtInVariable)
129 {
130 const TType *type = StaticType::Get<EbtInt, EbpHigh, EvqUniform, 1, 1>();
131 const TVariable *drawID =
132 new TVariable(symbolTable, kEmulatedGLDrawIDName, type, SymbolType::AngleInternal);
133 const TIntermSymbol *drawIDSymbol = new TIntermSymbol(drawID);
134
135 // AngleInternal variables don't get collected
136 ShaderVariable uniform;
137 uniform.name = kEmulatedGLDrawIDName.data();
138 uniform.mappedName = kEmulatedGLDrawIDName.data();
139 uniform.type = GLVariableType(*type);
140 uniform.precision = GLVariablePrecision(*type);
141 uniform.staticUse = symbolTable->isStaticallyUsed(*builtInVariable);
142 uniform.active = true;
143 uniform.binding = type->getLayoutQualifier().binding;
144 uniform.location = type->getLayoutQualifier().location;
145 uniform.offset = type->getLayoutQualifier().offset;
146 uniform.rasterOrdered = type->getLayoutQualifier().rasterOrdered;
147 uniform.readonly = type->getMemoryQualifier().readonly;
148 uniform.writeonly = type->getMemoryQualifier().writeonly;
149 uniforms->push_back(uniform);
150
151 DeclareGlobalVariable(root, drawID);
152 if (!ReplaceVariableWithTyped(compiler, root, builtInVariable, drawIDSymbol))
153 {
154 return false;
155 }
156 }
157
158 return true;
159 }
160
EmulateGLBaseVertexBaseInstance(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,std::vector<sh::ShaderVariable> * uniforms,bool addBaseVertexToVertexID)161 bool EmulateGLBaseVertexBaseInstance(TCompiler *compiler,
162 TIntermBlock *root,
163 TSymbolTable *symbolTable,
164 std::vector<sh::ShaderVariable> *uniforms,
165 bool addBaseVertexToVertexID)
166 {
167 bool addBaseVertex = false, addBaseInstance = false;
168 ShaderVariable uniformBaseVertex, uniformBaseInstance;
169
170 if (addBaseVertexToVertexID)
171 {
172 // This is a workaround for Mac AMD GPU
173 // Replace gl_VertexID with (gl_VertexID + gl_BaseVertex)
174 AddBaseVertexToGLVertexIDTraverser traverserVertexID;
175 root->traverse(&traverserVertexID);
176 if (!traverserVertexID.updateTree(compiler, root))
177 {
178 return false;
179 }
180 }
181
182 FindGLBaseVertexTraverser traverserBaseVertex;
183 root->traverse(&traverserBaseVertex);
184 const TVariable *builtInVariableBaseVertex =
185 traverserBaseVertex.getGLBaseVertexBuiltinVariable();
186
187 if (builtInVariableBaseVertex)
188 {
189 const TVariable *baseVertex = BuiltInVariable::angle_BaseVertex();
190 const TType &type = baseVertex->getType();
191 const TIntermSymbol *baseVertexSymbol = new TIntermSymbol(baseVertex);
192
193 // AngleInternal variables don't get collected
194 uniformBaseVertex.name = kEmulatedGLBaseVertexName.data();
195 uniformBaseVertex.mappedName = kEmulatedGLBaseVertexName.data();
196 uniformBaseVertex.type = GLVariableType(type);
197 uniformBaseVertex.precision = GLVariablePrecision(type);
198 uniformBaseVertex.staticUse = symbolTable->isStaticallyUsed(*builtInVariableBaseVertex);
199 uniformBaseVertex.active = true;
200 uniformBaseVertex.binding = type.getLayoutQualifier().binding;
201 uniformBaseVertex.location = type.getLayoutQualifier().location;
202 uniformBaseVertex.offset = type.getLayoutQualifier().offset;
203 uniformBaseVertex.rasterOrdered = type.getLayoutQualifier().rasterOrdered;
204 uniformBaseVertex.readonly = type.getMemoryQualifier().readonly;
205 uniformBaseVertex.writeonly = type.getMemoryQualifier().writeonly;
206 addBaseVertex = true;
207
208 DeclareGlobalVariable(root, baseVertex);
209 if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseVertex, baseVertexSymbol))
210 {
211 return false;
212 }
213 }
214
215 FindGLBaseInstanceTraverser traverserInstance;
216 root->traverse(&traverserInstance);
217 const TVariable *builtInVariableBaseInstance =
218 traverserInstance.getGLBaseInstanceBuiltinVariable();
219
220 if (builtInVariableBaseInstance)
221 {
222 const TVariable *baseInstance = BuiltInVariable::angle_BaseInstance();
223 const TType &type = baseInstance->getType();
224 const TIntermSymbol *baseInstanceSymbol = new TIntermSymbol(baseInstance);
225
226 // AngleInternal variables don't get collected
227 uniformBaseInstance.name = kEmulatedGLBaseInstanceName.data();
228 uniformBaseInstance.mappedName = kEmulatedGLBaseInstanceName.data();
229 uniformBaseInstance.type = GLVariableType(type);
230 uniformBaseInstance.precision = GLVariablePrecision(type);
231 uniformBaseInstance.staticUse = symbolTable->isStaticallyUsed(*builtInVariableBaseInstance);
232 uniformBaseInstance.active = true;
233 uniformBaseInstance.binding = type.getLayoutQualifier().binding;
234 uniformBaseInstance.location = type.getLayoutQualifier().location;
235 uniformBaseInstance.offset = type.getLayoutQualifier().offset;
236 uniformBaseInstance.rasterOrdered = type.getLayoutQualifier().rasterOrdered;
237 uniformBaseInstance.readonly = type.getMemoryQualifier().readonly;
238 uniformBaseInstance.writeonly = type.getMemoryQualifier().writeonly;
239 addBaseInstance = true;
240
241 DeclareGlobalVariable(root, baseInstance);
242 if (!ReplaceVariableWithTyped(compiler, root, builtInVariableBaseInstance,
243 baseInstanceSymbol))
244 {
245 return false;
246 }
247 }
248
249 // Make sure the order in uniforms is the same as the traverse order
250 if (addBaseInstance)
251 {
252 uniforms->push_back(uniformBaseInstance);
253 }
254 if (addBaseVertex)
255 {
256 uniforms->push_back(uniformBaseVertex);
257 }
258
259 return true;
260 }
261
262 } // namespace sh
263