1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include <cstring>
8 #include <unordered_map>
9 #include <unordered_set>
10
11 #include "compiler/translator/IntermRebuild.h"
12 #include "compiler/translator/msl/AstHelpers.h"
13 #include "compiler/translator/msl/DiscoverDependentFunctions.h"
14 #include "compiler/translator/msl/IdGen.h"
15 #include "compiler/translator/msl/MapSymbols.h"
16 #include "compiler/translator/msl/Pipeline.h"
17 #include "compiler/translator/msl/RewritePipelines.h"
18 #include "compiler/translator/msl/SymbolEnv.h"
19 #include "compiler/translator/msl/TranslatorMSL.h"
20 #include "compiler/translator/tree_ops/PruneNoOps.h"
21 #include "compiler/translator/tree_util/DriverUniform.h"
22 #include "compiler/translator/tree_util/FindMain.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/util.h"
25
26 using namespace sh;
27
28 ////////////////////////////////////////////////////////////////////////////////
29
30 namespace
31 {
32
IsVariableInvariant(const std::vector<sh::ShaderVariable> & mVars,const ImmutableString & name)33 bool IsVariableInvariant(const std::vector<sh::ShaderVariable> &mVars, const ImmutableString &name)
34 {
35 for (const auto &var : mVars)
36 {
37 if (name == var.name)
38 {
39 return var.isInvariant;
40 }
41 }
42 // TODO(kpidington): this should be UNREACHABLE() but isn't because the translator generates
43 // declarations to unused built-in variables.
44 return false;
45 }
46
47 using VariableSet = std::unordered_set<const TVariable *>;
48 using VariableList = std::vector<const TVariable *>;
49
50 ////////////////////////////////////////////////////////////////////////////////
51
52 struct PipelineStructInfo
53 {
54 VariableSet pipelineVariables;
55 PipelineScoped<TStructure> pipelineStruct;
56 const TFunction *funcOriginalToModified = nullptr;
57 const TFunction *funcModifiedToOriginal = nullptr;
58
isEmpty__anon7ec8f16e0111::PipelineStructInfo59 bool isEmpty() const
60 {
61 if (pipelineStruct.isTotallyEmpty())
62 {
63 ASSERT(pipelineVariables.empty());
64 return true;
65 }
66 else
67 {
68 ASSERT(pipelineStruct.isTotallyFull());
69 ASSERT(!pipelineVariables.empty());
70 return false;
71 }
72 }
73 };
74
75 class GeneratePipelineStruct : private TIntermRebuild
76 {
77 private:
78 const Pipeline &mPipeline;
79 SymbolEnv &mSymbolEnv;
80 const std::vector<sh::ShaderVariable> *mVariableInfos;
81 VariableList mPipelineVariableList;
82 IdGen &mIdGen;
83 PipelineStructInfo mInfo;
84
85 public:
Exec(PipelineStructInfo & out,TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)86 static bool Exec(PipelineStructInfo &out,
87 TCompiler &compiler,
88 TIntermBlock &root,
89 IdGen &idGen,
90 const Pipeline &pipeline,
91 SymbolEnv &symbolEnv,
92 const std::vector<sh::ShaderVariable> *variableInfos)
93 {
94 GeneratePipelineStruct self(compiler, idGen, pipeline, symbolEnv, variableInfos);
95 if (!self.exec(root))
96 {
97 return false;
98 }
99 out = self.mInfo;
100 return true;
101 }
102
103 private:
GeneratePipelineStruct(TCompiler & compiler,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfos)104 GeneratePipelineStruct(TCompiler &compiler,
105 IdGen &idGen,
106 const Pipeline &pipeline,
107 SymbolEnv &symbolEnv,
108 const std::vector<sh::ShaderVariable> *variableInfos)
109 : TIntermRebuild(compiler, true, true),
110 mPipeline(pipeline),
111 mSymbolEnv(symbolEnv),
112 mVariableInfos(variableInfos),
113 mIdGen(idGen)
114 {}
115
exec(TIntermBlock & root)116 bool exec(TIntermBlock &root)
117 {
118 if (!rebuildRoot(root))
119 {
120 return false;
121 }
122
123 if (mInfo.pipelineVariables.empty())
124 {
125 return true;
126 }
127
128 TIntermSequence seq;
129
130 const TStructure &pipelineStruct = [&]() -> const TStructure & {
131 if (mPipeline.globalInstanceVar)
132 {
133 return *mPipeline.globalInstanceVar->getType().getStruct();
134 }
135 else
136 {
137 return createInternalPipelineStruct(root, seq);
138 }
139 }();
140
141 ModifiedStructMachineries modifiedMachineries;
142 const bool isUBO = mPipeline.type == Pipeline::Type::UniformBuffer;
143 const bool isUniform = mPipeline.type == Pipeline::Type::UniformBuffer ||
144 mPipeline.type == Pipeline::Type::UserUniforms;
145 const bool useAttributeAliasing =
146 mPipeline.type == Pipeline::Type::VertexIn && mCompiler.supportsAttributeAliasing();
147 const bool modified = TryCreateModifiedStruct(
148 mCompiler, mSymbolEnv, mIdGen, mPipeline.externalStructModifyConfig(), pipelineStruct,
149 mPipeline.getStructTypeName(Pipeline::Variant::Modified), modifiedMachineries, isUBO,
150 !isUniform, useAttributeAliasing);
151
152 if (modified)
153 {
154 ASSERT(mPipeline.type != Pipeline::Type::Texture);
155 ASSERT(mPipeline.type == Pipeline::Type::AngleUniforms ||
156 !mPipeline.globalInstanceVar); // This shouldn't happen by construction.
157
158 auto getFunction = [](sh::TIntermFunctionDefinition *funcDecl) {
159 return funcDecl ? funcDecl->getFunction() : nullptr;
160 };
161
162 const size_t size = modifiedMachineries.size();
163 ASSERT(size > 0);
164 for (size_t i = 0; i < size; ++i)
165 {
166 const ModifiedStructMachinery &machinery = modifiedMachineries.at(i);
167 ASSERT(machinery.modifiedStruct);
168
169 seq.push_back(new TIntermDeclaration{
170 &CreateStructTypeVariable(mSymbolTable, *machinery.modifiedStruct)});
171
172 if (mPipeline.isPipelineOut())
173 {
174 ASSERT(machinery.funcOriginalToModified);
175 ASSERT(!machinery.funcModifiedToOriginal);
176 seq.push_back(machinery.funcOriginalToModified);
177 }
178 else
179 {
180 ASSERT(machinery.funcModifiedToOriginal);
181 ASSERT(!machinery.funcOriginalToModified);
182 seq.push_back(machinery.funcModifiedToOriginal);
183 }
184
185 if (i == size - 1)
186 {
187 mInfo.funcOriginalToModified = getFunction(machinery.funcOriginalToModified);
188 mInfo.funcModifiedToOriginal = getFunction(machinery.funcModifiedToOriginal);
189
190 mInfo.pipelineStruct.internal = &pipelineStruct;
191 mInfo.pipelineStruct.external =
192 modified ? machinery.modifiedStruct : &pipelineStruct;
193 }
194 }
195 }
196 else
197 {
198 mInfo.pipelineStruct.internal = &pipelineStruct;
199 mInfo.pipelineStruct.external = &pipelineStruct;
200 }
201
202 if (mPipeline.type == Pipeline::Type::FragmentOut &&
203 mCompiler.hasPixelLocalStorageUniforms() &&
204 mCompiler.getPixelLocalStorageType() == ShPixelLocalStorageType::FramebufferFetch)
205 {
206 auto &fields = *new TFieldList();
207 for (const TField *field : mInfo.pipelineStruct.external->fields())
208 {
209 if (field->type()->getQualifier() == EvqFragmentInOut)
210 {
211 fields.push_back(new TField(&CloneType(*field->type()), field->name(),
212 kNoSourceLoc, field->symbolType()));
213 }
214 }
215 TStructure *extraStruct =
216 new TStructure(&mSymbolTable, ImmutableString("LastFragmentOut"), &fields,
217 SymbolType::AngleInternal);
218 seq.push_back(
219 new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, *extraStruct)});
220 mInfo.pipelineStruct.externalExtra = extraStruct;
221 }
222
223 root.insertChildNodes(FindMainIndex(&root), seq);
224
225 return true;
226 }
227
228 private:
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)229 PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
230 {
231 return {node, VisitBits::Neither};
232 }
visitDeclarationPost(TIntermDeclaration & declNode)233 PostResult visitDeclarationPost(TIntermDeclaration &declNode) override
234 {
235 Declaration decl = ViewDeclaration(declNode);
236 const TVariable &var = decl.symbol.variable();
237 if (mPipeline.uses(var))
238 {
239 ASSERT(mInfo.pipelineVariables.find(&var) == mInfo.pipelineVariables.end());
240 mInfo.pipelineVariables.insert(&var);
241 mPipelineVariableList.push_back(&var);
242 return nullptr;
243 }
244
245 return declNode;
246 }
247
createInternalPipelineStruct(TIntermBlock & root,TIntermSequence & outDeclSeq)248 const TStructure &createInternalPipelineStruct(TIntermBlock &root, TIntermSequence &outDeclSeq)
249 {
250 auto &fields = *new TFieldList();
251
252 switch (mPipeline.type)
253 {
254 case Pipeline::Type::Texture:
255 {
256 for (const TVariable *var : mPipelineVariableList)
257 {
258 const TType &varType = var->getType();
259 const TBasicType samplerType = varType.getBasicType();
260
261 const TStructure &textureEnv = mSymbolEnv.getTextureEnv(samplerType);
262 auto *textureEnvType = new TType(&textureEnv, false);
263 if (varType.isArray())
264 {
265 textureEnvType->makeArrays(varType.getArraySizes());
266 }
267
268 fields.push_back(
269 new TField(textureEnvType, var->name(), kNoSourceLoc, var->symbolType()));
270 }
271 }
272 break;
273
274 case Pipeline::Type::Image:
275 {
276 for (const TVariable *var : mPipelineVariableList)
277 {
278 auto &type = CloneType(var->getType());
279 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
280 fields.push_back(field);
281 }
282 }
283 break;
284
285 case Pipeline::Type::UniformBuffer:
286 {
287 for (const TVariable *var : mPipelineVariableList)
288 {
289 auto &type = CloneType(var->getType());
290 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
291 mSymbolEnv.markAsPointer(*field, AddressSpace::Constant);
292 mSymbolEnv.markAsUBO(*field);
293 mSymbolEnv.markAsPointer(*var, AddressSpace::Constant);
294 fields.push_back(field);
295 }
296 }
297 break;
298 default:
299 {
300 for (const TVariable *var : mPipelineVariableList)
301 {
302 auto &type = CloneType(var->getType());
303 if (mVariableInfos && IsVariableInvariant(*mVariableInfos, var->name()))
304 {
305 type.setInvariant(true);
306 }
307 auto *field = new TField(&type, var->name(), kNoSourceLoc, var->symbolType());
308 fields.push_back(field);
309 }
310 }
311 break;
312 }
313
314 Name pipelineStructName = mPipeline.getStructTypeName(Pipeline::Variant::Original);
315 auto &s = *new TStructure(&mSymbolTable, pipelineStructName.rawName(), &fields,
316 pipelineStructName.symbolType());
317
318 outDeclSeq.push_back(new TIntermDeclaration{&CreateStructTypeVariable(mSymbolTable, s)});
319
320 return s;
321 }
322 };
323
324 ////////////////////////////////////////////////////////////////////////////////
325
CreatePipelineMainLocalVar(TSymbolTable & symbolTable,const Pipeline & pipeline,PipelineScoped<TStructure> pipelineStruct)326 PipelineScoped<TVariable> CreatePipelineMainLocalVar(TSymbolTable &symbolTable,
327 const Pipeline &pipeline,
328 PipelineScoped<TStructure> pipelineStruct)
329 {
330 ASSERT(pipelineStruct.isTotallyFull());
331
332 PipelineScoped<TVariable> pipelineMainLocalVar;
333
334 auto populateExternalMainLocalVar = [&]() {
335 ASSERT(!pipelineMainLocalVar.external);
336 pipelineMainLocalVar.external = &CreateInstanceVariable(
337 symbolTable, *pipelineStruct.external,
338 pipeline.getStructInstanceName(pipelineStruct.isUniform()
339 ? Pipeline::Variant::Original
340 : Pipeline::Variant::Modified));
341 };
342
343 auto populateDistinctInternalMainLocalVar = [&]() {
344 ASSERT(!pipelineMainLocalVar.internal);
345 pipelineMainLocalVar.internal =
346 &CreateInstanceVariable(symbolTable, *pipelineStruct.internal,
347 pipeline.getStructInstanceName(Pipeline::Variant::Original));
348 };
349
350 if (pipeline.type == Pipeline::Type::InstanceId)
351 {
352 populateDistinctInternalMainLocalVar();
353 }
354 else if (pipeline.alwaysRequiresLocalVariableDeclarationInMain())
355 {
356 populateExternalMainLocalVar();
357
358 if (pipelineStruct.isUniform())
359 {
360 pipelineMainLocalVar.internal = pipelineMainLocalVar.external;
361 }
362 else
363 {
364 populateDistinctInternalMainLocalVar();
365 }
366 }
367 else if (!pipelineStruct.isUniform())
368 {
369 populateDistinctInternalMainLocalVar();
370 }
371
372 return pipelineMainLocalVar;
373 }
374
// Decides how each function that participates in `mPipeline` is rewritten so that the pipeline
// struct instance can be threaded through calls, and caches the rewritten functions.
class PipelineFunctionEnv
{
  private:
    TCompiler &mCompiler;
    SymbolEnv &mSymbolEnv;
    TSymbolTable &mSymbolTable;
    IdGen &mIdGen;
    const Pipeline &mPipeline;
    // Functions, as they appear in the original tree, that need the pipeline threaded through.
    const std::unordered_set<const TFunction *> &mPipelineFunctions;
    const PipelineScoped<TStructure> mPipelineStruct;
    // Shared with the owning rewriter. Some main()-specific branches of getUpdatedFunction()
    // fill in `external` and/or `externalExtra` here.
    PipelineScoped<TVariable> &mPipelineMainLocalVar;
    // main()'s original parameter count, recorded when main() is first updated. Used to locate
    // parameters appended after the original ones.
    size_t mFirstParamIdxInMainFn = 0;

    // Maps both an original function and its rewritten version to the rewritten version.
    std::unordered_map<const TFunction *, const TFunction *> mFuncMap;

  public:
    PipelineFunctionEnv(TCompiler &compiler,
                        SymbolEnv &symbolEnv,
                        IdGen &idGen,
                        const Pipeline &pipeline,
                        const std::unordered_set<const TFunction *> &pipelineFunctions,
                        PipelineScoped<TStructure> pipelineStruct,
                        PipelineScoped<TVariable> &pipelineMainLocalVar)
        : mCompiler(compiler),
          mSymbolEnv(symbolEnv),
          mSymbolTable(symbolEnv.symbolTable()),
          mIdGen(idGen),
          mPipeline(pipeline),
          mPipelineFunctions(pipelineFunctions),
          mPipelineStruct(pipelineStruct),
          mPipelineMainLocalVar(pipelineMainLocalVar)
    {}

    // True if `func` is one of the original (not yet rewritten) pipeline functions.
    bool isOriginalPipelineFunction(const TFunction &func) const
    {
        return mPipelineFunctions.find(&func) != mPipelineFunctions.end();
    }

    // True if `func` is a function previously produced by getUpdatedFunction(). Such functions
    // map to themselves in mFuncMap.
    bool isUpdatedPipelineFunction(const TFunction &func) const
    {
        auto it = mFuncMap.find(&func);
        if (it == mFuncMap.end())
        {
            return false;
        }
        return &func == it->second;
    }

    // Returns the rewritten version of `func`, creating and caching it on first request.
    // Passing an already-rewritten function returns it unchanged.
    //
    // The branches below are checked in order; only the first matching one applies.
    const TFunction &getUpdatedFunction(const TFunction &func)
    {
        ASSERT(isOriginalPipelineFunction(func) || isUpdatedPipelineFunction(func));

        const TFunction *newFunc;

        auto it = mFuncMap.find(&func);
        if (it == mFuncMap.end())
        {
            const bool isMain = func.isMain();
            if (isMain)
            {
                mFirstParamIdxInMainFn = func.getParamCount();
            }

            if (isMain && mPipeline.isPipelineOut())
            {
                // Output pipelines: main() returns the external struct instead of void.
                ASSERT(func.getReturnType().getBasicType() == TBasicType::EbtVoid);
                newFunc = &CloneFunctionAndChangeReturnType(mSymbolTable, nullptr, func,
                                                            *mPipelineStruct.external);
                if (mPipeline.type == Pipeline::Type::FragmentOut &&
                    mCompiler.hasPixelLocalStorageUniforms() &&
                    mCompiler.getPixelLocalStorageType() ==
                        ShPixelLocalStorageType::FramebufferFetch)
                {
                    // Add an input argument to main() that contains the current framebuffer
                    // attachment values, for loading pixel local storage.
                    TType *type = new TType(mPipelineStruct.externalExtra, true);
                    TVariable *lastFragmentOut =
                        new TVariable(&mSymbolTable, ImmutableString("lastFragmentOut"), type,
                                      SymbolType::AngleInternal);
                    newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, *newFunc,
                                                            *lastFragmentOut);
                    // Initialize the main local variable with the current framebuffer contents.
                    mPipelineMainLocalVar.externalExtra = lastFragmentOut;
                }
            }
            else if (isMain && (mPipeline.type == Pipeline::Type::InvocationVertexGlobals))
            {
                // main()'s signature is left unchanged here. A `vertexIDMetal` uint variable is
                // only recorded as the external main local var. How it gets bound is not visible
                // in this class (NOTE(review): presumably handled when emitting main).
                ASSERT(mPipelineStruct.external->fields().size() == 1);
                ASSERT(mPipelineStruct.external->fields()[0]->type()->getQualifier() ==
                       TQualifier::EvqVertexID);
                auto *vertexIDMetalVar =
                    new TVariable(&mSymbolTable, ImmutableString("vertexIDMetal"),
                                  new TType(TBasicType::EbtUInt), SymbolType::AngleInternal);
                newFunc                              = &func;
                mPipelineMainLocalVar.external = vertexIDMetalVar;
            }
            else if (isMain && (mPipeline.type == Pipeline::Type::InvocationFragmentGlobals))
            {
                // Append one main() parameter per field of the external struct.
                std::vector<const TVariable *> variables;
                for (const TField *field : mPipelineStruct.external->fields())
                {
                    variables.push_back(new TVariable(&mSymbolTable, field->name(), field->type(),
                                                      field->symbolType()));
                }
                newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
            }
            else if (isMain && mPipeline.type == Pipeline::Type::Texture)
            {
                // Append two main() parameters per texture-environment field, one for each of
                // its two subfields. Each parameter keeps the field's array shape. The original
                // GLSL name is recorded in the Metal reflection for each new parameter.
                std::vector<const TVariable *> variables;
                TranslatorMetalReflection *reflection =
                    mtl::getTranslatorMetalReflection(&mCompiler);
                for (const TField *field : mPipelineStruct.external->fields())
                {
                    const TStructure *textureEnv = field->type()->getStruct();
                    ASSERT(textureEnv && textureEnv->fields().size() == 2);
                    for (const TField *subfield : textureEnv->fields())
                    {
                        const Name name = mIdGen.createNewName({field->name(), subfield->name()});
                        TType &type     = *new TType(*subfield->type());
                        ASSERT(!type.isArray());
                        type.makeArrays(field->type()->getArraySizes());
                        auto *var =
                            new TVariable(&mSymbolTable, name.rawName(), &type, name.symbolType());
                        variables.push_back(var);
                        reflection->addOriginalName(var->uniqueId().get(), field->name().data());
                    }
                }
                newFunc = &CloneFunctionAndAppendParams(mSymbolTable, nullptr, func, variables);
            }
            else if (isMain && mPipeline.type == Pipeline::Type::InstanceId)
            {
                // Prepend the instance id and the base instance as uint parameters of main().
                Name instanceIdName = mPipeline.getStructInstanceName(Pipeline::Variant::Modified);
                auto *instanceIdVar =
                    new TVariable(&mSymbolTable, instanceIdName.rawName(),
                                  new TType(TBasicType::EbtUInt), instanceIdName.symbolType());

                auto *baseInstanceVar =
                    new TVariable(&mSymbolTable, kBaseInstanceName.rawName(),
                                  new TType(TBasicType::EbtUInt), kBaseInstanceName.symbolType());

                newFunc = &CloneFunctionAndPrependTwoParams(mSymbolTable, nullptr, func,
                                                            *instanceIdVar, *baseInstanceVar);
                mPipelineMainLocalVar.external      = instanceIdVar;
                mPipelineMainLocalVar.externalExtra = baseInstanceVar;
            }
            else if (isMain && mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
            {
                // main() keeps its signature; its locals were created up front.
                ASSERT(mPipelineMainLocalVar.isTotallyFull());
                newFunc = &func;
            }
            else
            {
                // General case: prepend a struct-instance parameter. For main() with a
                // non-uniform local var this is the external (Modified) instance; otherwise it
                // is the internal (Original) instance.
                const TVariable *var;
                AddressSpace addressSpace;

                if (isMain && !mPipelineMainLocalVar.isUniform())
                {
                    var = &CreateInstanceVariable(
                        mSymbolTable, *mPipelineStruct.external,
                        mPipeline.getStructInstanceName(Pipeline::Variant::Modified));
                    addressSpace = mPipeline.externalAddressSpace();
                }
                else
                {
                    var = &CreateInstanceVariable(
                        mSymbolTable, *mPipelineStruct.internal,
                        mPipeline.getStructInstanceName(Pipeline::Variant::Original));
                    addressSpace = mPipelineMainLocalVar.isUniform()
                                       ? mPipeline.externalAddressSpace()
                                       : AddressSpace::Thread;
                }

                // The parameter is marked as a reference unless it is main()'s parameter for
                // a VertexIn, FragmentIn or Image pipeline.
                bool markAsReference = true;
                if (isMain)
                {
                    switch (mPipeline.type)
                    {
                        case Pipeline::Type::VertexIn:
                        case Pipeline::Type::FragmentIn:
                        case Pipeline::Type::Image:
                            markAsReference = false;
                            break;

                        default:
                            break;
                    }
                }

                if (markAsReference)
                {
                    mSymbolEnv.markAsReference(*var, addressSpace);
                }

                newFunc = &CloneFunctionAndPrependParam(mSymbolTable, nullptr, func, *var);
            }

            // Cache under both keys so that later lookups with either function succeed and
            // isUpdatedPipelineFunction() recognizes the rewritten one.
            mFuncMap[&func]   = newFunc;
            mFuncMap[newFunc] = newFunc;
        }
        else
        {
            newFunc = it->second;
        }

        return *newFunc;
    }

    // Returns a prototype node for the rewritten function, or nullptr if `funcProtoNode` does
    // not belong to the pipeline (neither original nor already rewritten).
    TIntermFunctionPrototype *createUpdatedFunctionPrototype(
        TIntermFunctionPrototype &funcProtoNode)
    {
        const TFunction &func = *funcProtoNode.getFunction();
        if (!isOriginalPipelineFunction(func) && !isUpdatedPipelineFunction(func))
        {
            return nullptr;
        }
        const TFunction &newFunc = getUpdatedFunction(func);
        return new TIntermFunctionPrototype(&newFunc);
    }

    size_t getFirstParamIdxInMainFn() const { return mFirstParamIdxInMainFn; }
};
596
597 class UpdatePipelineFunctions : private TIntermRebuild
598 {
599 private:
600 const Pipeline &mPipeline;
601 const PipelineScoped<TStructure> mPipelineStruct;
602 PipelineScoped<TVariable> &mPipelineMainLocalVar;
603 SymbolEnv &mSymbolEnv;
604 PipelineFunctionEnv mEnv;
605 const TFunction *mFuncOriginalToModified;
606 const TFunction *mFuncModifiedToOriginal;
607
608 public:
ThreadPipeline(TCompiler & compiler,TIntermBlock & root,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)609 static bool ThreadPipeline(TCompiler &compiler,
610 TIntermBlock &root,
611 const Pipeline &pipeline,
612 const std::unordered_set<const TFunction *> &pipelineFunctions,
613 PipelineScoped<TStructure> pipelineStruct,
614 PipelineScoped<TVariable> &pipelineMainLocalVar,
615 IdGen &idGen,
616 SymbolEnv &symbolEnv,
617 const TFunction *funcOriginalToModified,
618 const TFunction *funcModifiedToOriginal)
619 {
620 UpdatePipelineFunctions self(compiler, pipeline, pipelineFunctions, pipelineStruct,
621 pipelineMainLocalVar, idGen, symbolEnv, funcOriginalToModified,
622 funcModifiedToOriginal);
623 if (!self.rebuildRoot(root))
624 {
625 return false;
626 }
627 return true;
628 }
629
630 private:
UpdatePipelineFunctions(TCompiler & compiler,const Pipeline & pipeline,const std::unordered_set<const TFunction * > & pipelineFunctions,PipelineScoped<TStructure> pipelineStruct,PipelineScoped<TVariable> & pipelineMainLocalVar,IdGen & idGen,SymbolEnv & symbolEnv,const TFunction * funcOriginalToModified,const TFunction * funcModifiedToOriginal)631 UpdatePipelineFunctions(TCompiler &compiler,
632 const Pipeline &pipeline,
633 const std::unordered_set<const TFunction *> &pipelineFunctions,
634 PipelineScoped<TStructure> pipelineStruct,
635 PipelineScoped<TVariable> &pipelineMainLocalVar,
636 IdGen &idGen,
637 SymbolEnv &symbolEnv,
638 const TFunction *funcOriginalToModified,
639 const TFunction *funcModifiedToOriginal)
640 : TIntermRebuild(compiler, false, true),
641 mPipeline(pipeline),
642 mPipelineStruct(pipelineStruct),
643 mPipelineMainLocalVar(pipelineMainLocalVar),
644 mSymbolEnv(symbolEnv),
645 mEnv(compiler,
646 symbolEnv,
647 idGen,
648 pipeline,
649 pipelineFunctions,
650 pipelineStruct,
651 mPipelineMainLocalVar),
652 mFuncOriginalToModified(funcOriginalToModified),
653 mFuncModifiedToOriginal(funcModifiedToOriginal)
654 {
655 ASSERT(mPipelineStruct.isTotallyFull());
656 }
657
getInternalPipelineVariable(const TFunction & pipelineFunc)658 const TVariable &getInternalPipelineVariable(const TFunction &pipelineFunc)
659 {
660 if (pipelineFunc.isMain() && (mPipeline.alwaysRequiresLocalVariableDeclarationInMain() ||
661 !mPipelineMainLocalVar.isUniform()))
662 {
663 ASSERT(mPipelineMainLocalVar.internal);
664 return *mPipelineMainLocalVar.internal;
665 }
666 else
667 {
668 ASSERT(pipelineFunc.getParamCount() > 0);
669 return *pipelineFunc.getParam(0);
670 }
671 }
672
getExternalPipelineVariable(const TFunction & mainFunc)673 const TVariable &getExternalPipelineVariable(const TFunction &mainFunc)
674 {
675 ASSERT(mainFunc.isMain());
676 if (mPipelineMainLocalVar.external)
677 {
678 return *mPipelineMainLocalVar.external;
679 }
680 else
681 {
682 ASSERT(mainFunc.getParamCount() > 0);
683 return *mainFunc.getParam(0);
684 }
685 }
686
getExternalExtraPipelineVariable(const TFunction & mainFunc)687 const TVariable &getExternalExtraPipelineVariable(const TFunction &mainFunc)
688 {
689 ASSERT(mainFunc.isMain());
690 if (mPipelineMainLocalVar.externalExtra)
691 {
692 return *mPipelineMainLocalVar.externalExtra;
693 }
694 else
695 {
696 ASSERT(mainFunc.getParamCount() > 1);
697 return *mainFunc.getParam(1);
698 }
699 }
700
visitAggregatePost(TIntermAggregate & callNode)701 PostResult visitAggregatePost(TIntermAggregate &callNode) override
702 {
703 if (callNode.isConstructor())
704 {
705 return callNode;
706 }
707 else
708 {
709 const TFunction &oldCalledFunc = *callNode.getFunction();
710 if (!mEnv.isOriginalPipelineFunction(oldCalledFunc))
711 {
712 return callNode;
713 }
714 const TFunction &newCalledFunc = mEnv.getUpdatedFunction(oldCalledFunc);
715
716 const TFunction *oldOwnerFunc = getParentFunction();
717 ASSERT(oldOwnerFunc);
718 const TFunction &newOwnerFunc = mEnv.getUpdatedFunction(*oldOwnerFunc);
719
720 return *TIntermAggregate::CreateFunctionCall(
721 newCalledFunc, &CloneSequenceAndPrepend(
722 *callNode.getSequence(),
723 *new TIntermSymbol(&getInternalPipelineVariable(newOwnerFunc))));
724 }
725 }
726
visitFunctionPrototypePost(TIntermFunctionPrototype & funcProtoNode)727 PostResult visitFunctionPrototypePost(TIntermFunctionPrototype &funcProtoNode) override
728 {
729 TIntermFunctionPrototype *newFuncProtoNode =
730 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
731 if (newFuncProtoNode == nullptr)
732 {
733 return funcProtoNode;
734 }
735 return *newFuncProtoNode;
736 }
737
visitFunctionDefinitionPost(TIntermFunctionDefinition & funcDefNode)738 PostResult visitFunctionDefinitionPost(TIntermFunctionDefinition &funcDefNode) override
739 {
740 if (funcDefNode.getFunction()->isMain())
741 {
742 return visitMain(funcDefNode);
743 }
744 else
745 {
746 return visitNonMain(funcDefNode);
747 }
748 }
749
visitNonMain(TIntermFunctionDefinition & funcDefNode)750 TIntermNode &visitNonMain(TIntermFunctionDefinition &funcDefNode)
751 {
752 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
753 ASSERT(!funcProtoNode.getFunction()->isMain());
754
755 TIntermFunctionPrototype *newFuncProtoNode =
756 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
757 if (newFuncProtoNode == nullptr)
758 {
759 return funcDefNode;
760 }
761
762 const TFunction &func = *newFuncProtoNode->getFunction();
763 ASSERT(!func.isMain());
764
765 TIntermBlock *body = funcDefNode.getBody();
766
767 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
768 }
769
visitMain(TIntermFunctionDefinition & funcDefNode)770 TIntermNode &visitMain(TIntermFunctionDefinition &funcDefNode)
771 {
772 TIntermFunctionPrototype &funcProtoNode = *funcDefNode.getFunctionPrototype();
773 ASSERT(funcProtoNode.getFunction()->isMain());
774
775 TIntermFunctionPrototype *newFuncProtoNode =
776 mEnv.createUpdatedFunctionPrototype(funcProtoNode);
777 if (newFuncProtoNode == nullptr)
778 {
779 return funcDefNode;
780 }
781
782 const TFunction &func = *newFuncProtoNode->getFunction();
783 ASSERT(func.isMain());
784
785 auto callModifiedToOriginal = [&](TIntermBlock &body) {
786 ASSERT(mPipelineMainLocalVar.internal);
787 if (!mPipeline.isPipelineOut())
788 {
789 ASSERT(mFuncModifiedToOriginal);
790 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
791 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
792 body.appendStatement(TIntermAggregate::CreateFunctionCall(
793 *mFuncModifiedToOriginal, new TIntermSequence{m, o}));
794 }
795 };
796
797 auto callOriginalToModified = [&](TIntermBlock &body) {
798 ASSERT(mPipelineMainLocalVar.internal);
799 if (mPipeline.isPipelineOut())
800 {
801 ASSERT(mFuncOriginalToModified);
802 auto *o = new TIntermSymbol(mPipelineMainLocalVar.internal);
803 auto *m = new TIntermSymbol(&getExternalPipelineVariable(func));
804 body.appendStatement(TIntermAggregate::CreateFunctionCall(
805 *mFuncOriginalToModified, new TIntermSequence{o, m}));
806 }
807 };
808
809 TIntermBlock *body = funcDefNode.getBody();
810
811 if (mPipeline.alwaysRequiresLocalVariableDeclarationInMain())
812 {
813 ASSERT(mPipelineMainLocalVar.isTotallyFull());
814
815 auto *newBody = new TIntermBlock();
816 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
817
818 if (mPipeline.type == Pipeline::Type::InvocationVertexGlobals)
819 {
820 ASSERT(mPipelineStruct.external->fields().size() == 1);
821 ASSERT(mPipelineStruct.external->fields()[0]->type()->getQualifier() ==
822 TQualifier::EvqVertexID);
823 const TField *field = mPipelineStruct.external->fields()[0];
824 auto *var =
825 new TVariable(&mSymbolTable, field->name(), field->type(), field->symbolType());
826 auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, Name(*var));
827 auto vertexIDMetal = new TIntermSymbol(&getExternalPipelineVariable(func));
828 auto *assignNode = new TIntermBinary(
829 TOperator::EOpAssign, &accessNode,
830 &AsType(mSymbolEnv, *new TType(TBasicType::EbtInt), *vertexIDMetal));
831 newBody->appendStatement(assignNode);
832 }
833 else if (mPipeline.type == Pipeline::Type::InvocationFragmentGlobals)
834 {
835 // Populate struct instance with references to global pipeline variables.
836 for (const TField *field : mPipelineStruct.external->fields())
837 {
838 auto *var = new TVariable(&mSymbolTable, field->name(), field->type(),
839 field->symbolType());
840 auto *symbol = new TIntermSymbol(var);
841 auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, Name(*var));
842 auto *assignNode = new TIntermBinary(TOperator::EOpAssign, &accessNode, symbol);
843 newBody->appendStatement(assignNode);
844 }
845 }
846 else if (mPipeline.type == Pipeline::Type::FragmentOut &&
847 mCompiler.hasPixelLocalStorageUniforms() &&
848 mCompiler.getPixelLocalStorageType() ==
849 ShPixelLocalStorageType::FramebufferFetch)
850 {
851 ASSERT(mPipelineMainLocalVar.externalExtra);
852 auto &lastFragmentOut = *mPipelineMainLocalVar.externalExtra;
853 for (const TField *field : lastFragmentOut.getType().getStruct()->fields())
854 {
855 auto &accessNode = AccessField(*mPipelineMainLocalVar.internal, Name(*field));
856 auto &sourceNode = AccessField(lastFragmentOut, Name(*field));
857 auto *assignNode =
858 new TIntermBinary(TOperator::EOpAssign, &accessNode, &sourceNode);
859 newBody->appendStatement(assignNode);
860 }
861 }
862 else if (mPipeline.type == Pipeline::Type::Texture)
863 {
864 const TFieldList &fields = mPipelineStruct.external->fields();
865
866 ASSERT(func.getParamCount() >= mEnv.getFirstParamIdxInMainFn() + 2 * fields.size());
867 size_t paramIndex = mEnv.getFirstParamIdxInMainFn();
868
869 for (const TField *field : fields)
870 {
871 const TVariable &textureParam = *func.getParam(paramIndex++);
872 const TVariable &samplerParam = *func.getParam(paramIndex++);
873
874 auto go = [&](TIntermTyped &env, const int *index) {
875 TIntermTyped &textureField =
876 AccessField(AccessIndex(*env.deepCopy(), index),
877 Name("texture", SymbolType::BuiltIn));
878 TIntermTyped &samplerField =
879 AccessField(AccessIndex(*env.deepCopy(), index),
880 Name("sampler", SymbolType::BuiltIn));
881
882 auto mkAssign = [&](TIntermTyped &field, const TVariable ¶m) {
883 return new TIntermBinary(TOperator::EOpAssign, &field,
884 &mSymbolEnv.callFunctionOverload(
885 Name("addressof"), field.getType(),
886 *new TIntermSequence{&AccessIndex(
887 *new TIntermSymbol(¶m), index)}));
888 };
889
890 newBody->appendStatement(mkAssign(textureField, textureParam));
891 newBody->appendStatement(mkAssign(samplerField, samplerParam));
892 };
893
894 TIntermTyped &env = AccessField(*mPipelineMainLocalVar.internal, Name(*field));
895 const TType &envType = env.getType();
896
897 if (envType.isArray())
898 {
899 ASSERT(!envType.isArrayOfArrays());
900 const auto n = static_cast<int>(envType.getArraySizeProduct());
901 for (int i = 0; i < n; ++i)
902 {
903 go(env, &i);
904 }
905 }
906 else
907 {
908 go(env, nullptr);
909 }
910 }
911 }
912 else if (mPipeline.type == Pipeline::Type::InstanceId)
913 {
914 auto varInstanceId = new TIntermSymbol(&getExternalPipelineVariable(func));
915 auto varBaseInstance = new TIntermSymbol(&getExternalExtraPipelineVariable(func));
916
917 newBody->appendStatement(new TIntermBinary(
918 TOperator::EOpAssign,
919 &AccessFieldByIndex(*new TIntermSymbol(&getInternalPipelineVariable(func)), 0),
920 &AsType(
921 mSymbolEnv, *new TType(TBasicType::EbtInt),
922 *new TIntermBinary(TOperator::EOpSub, varInstanceId, varBaseInstance))));
923 }
924 else if (!mPipelineMainLocalVar.isUniform())
925 {
926 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.external});
927 callModifiedToOriginal(*newBody);
928 }
929
930 newBody->appendStatement(body);
931
932 if (!mPipelineMainLocalVar.isUniform())
933 {
934 callOriginalToModified(*newBody);
935 }
936
937 if (mPipeline.isPipelineOut())
938 {
939 newBody->appendStatement(new TIntermBranch(
940 TOperator::EOpReturn, new TIntermSymbol(mPipelineMainLocalVar.external)));
941 }
942
943 body = newBody;
944 }
945 else if (!mPipelineMainLocalVar.isUniform())
946 {
947 ASSERT(!mPipelineMainLocalVar.external);
948 ASSERT(mPipelineMainLocalVar.internal);
949
950 auto *newBody = new TIntermBlock();
951 newBody->appendStatement(new TIntermDeclaration{mPipelineMainLocalVar.internal});
952 callModifiedToOriginal(*newBody);
953 newBody->appendStatement(body);
954 callOriginalToModified(*newBody);
955 body = newBody;
956 }
957
958 return *new TIntermFunctionDefinition(newFuncProtoNode, body);
959 }
960 };
961
962 ////////////////////////////////////////////////////////////////////////////////
963
UpdatePipelineSymbols(Pipeline::Type pipelineType,TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,const VariableSet & pipelineVariables,PipelineScoped<TVariable> pipelineMainLocalVar)964 bool UpdatePipelineSymbols(Pipeline::Type pipelineType,
965 TCompiler &compiler,
966 TIntermBlock &root,
967 SymbolEnv &symbolEnv,
968 const VariableSet &pipelineVariables,
969 PipelineScoped<TVariable> pipelineMainLocalVar)
970 {
971 auto map = [&](const TFunction *owner, TIntermSymbol &symbol) -> TIntermNode & {
972 if (!owner)
973 return symbol;
974 const TVariable &var = symbol.variable();
975 if (pipelineVariables.find(&var) == pipelineVariables.end())
976 {
977 return symbol;
978 }
979 const TVariable *structInstanceVar;
980 if (owner->isMain() && pipelineType != Pipeline::Type::FragmentIn)
981 {
982 ASSERT(pipelineMainLocalVar.internal);
983 structInstanceVar = pipelineMainLocalVar.internal;
984 }
985 else
986 {
987 ASSERT(owner->getParamCount() > 0);
988 structInstanceVar = owner->getParam(0);
989 }
990 ASSERT(structInstanceVar);
991 return AccessField(*structInstanceVar, Name(var));
992 };
993 return MapSymbols(compiler, root, map);
994 }
995
996 ////////////////////////////////////////////////////////////////////////////////
997
RewritePipeline(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const Pipeline & pipeline,SymbolEnv & symbolEnv,const std::vector<sh::ShaderVariable> * variableInfo,PipelineScoped<TStructure> & outStruct)998 bool RewritePipeline(TCompiler &compiler,
999 TIntermBlock &root,
1000 IdGen &idGen,
1001 const Pipeline &pipeline,
1002 SymbolEnv &symbolEnv,
1003 const std::vector<sh::ShaderVariable> *variableInfo,
1004 PipelineScoped<TStructure> &outStruct)
1005 {
1006 ASSERT(outStruct.isTotallyEmpty());
1007
1008 TSymbolTable &symbolTable = compiler.getSymbolTable();
1009
1010 PipelineStructInfo psi;
1011 if (!GeneratePipelineStruct::Exec(psi, compiler, root, idGen, pipeline, symbolEnv,
1012 variableInfo))
1013 {
1014 return false;
1015 }
1016
1017 if (psi.isEmpty())
1018 {
1019 return true;
1020 }
1021
1022 const auto pipelineFunctions = DiscoverDependentFunctions(root, [&](const TVariable &var) {
1023 return psi.pipelineVariables.find(&var) != psi.pipelineVariables.end();
1024 });
1025
1026 auto pipelineMainLocalVar =
1027 CreatePipelineMainLocalVar(symbolTable, pipeline, psi.pipelineStruct);
1028
1029 if (!UpdatePipelineFunctions::ThreadPipeline(
1030 compiler, root, pipeline, pipelineFunctions, psi.pipelineStruct, pipelineMainLocalVar,
1031 idGen, symbolEnv, psi.funcOriginalToModified, psi.funcModifiedToOriginal))
1032 {
1033 return false;
1034 }
1035
1036 if (!pipeline.globalInstanceVar)
1037 {
1038 if (!UpdatePipelineSymbols(pipeline.type, compiler, root, symbolEnv, psi.pipelineVariables,
1039 pipelineMainLocalVar))
1040 {
1041 return false;
1042 }
1043 }
1044
1045 if (!PruneNoOps(&compiler, &root, &compiler.getSymbolTable()))
1046 {
1047 return false;
1048 }
1049
1050 outStruct = psi.pipelineStruct;
1051 return true;
1052 }
1053
1054 } // anonymous namespace
1055
RewritePipelines(TCompiler & compiler,TIntermBlock & root,const std::vector<sh::ShaderVariable> & inputVaryings,const std::vector<sh::ShaderVariable> & outputVaryings,IdGen & idGen,DriverUniform & angleUniformsGlobalInstanceVar,SymbolEnv & symbolEnv,PipelineStructs & outStructs)1056 bool sh::RewritePipelines(TCompiler &compiler,
1057 TIntermBlock &root,
1058 const std::vector<sh::ShaderVariable> &inputVaryings,
1059 const std::vector<sh::ShaderVariable> &outputVaryings,
1060 IdGen &idGen,
1061 DriverUniform &angleUniformsGlobalInstanceVar,
1062 SymbolEnv &symbolEnv,
1063 PipelineStructs &outStructs)
1064 {
1065 struct Info
1066 {
1067 Pipeline::Type pipelineType;
1068 PipelineScoped<TStructure> &outStruct;
1069 const TVariable *globalInstanceVar;
1070 const std::vector<sh::ShaderVariable> *variableInfo;
1071 };
1072
1073 Info infos[] = {
1074 {Pipeline::Type::InstanceId, outStructs.instanceId, nullptr, nullptr},
1075 {Pipeline::Type::Texture, outStructs.image, nullptr, nullptr},
1076 {Pipeline::Type::Image, outStructs.texture, nullptr, nullptr},
1077 {Pipeline::Type::NonConstantGlobals, outStructs.nonConstantGlobals, nullptr, nullptr},
1078 {Pipeline::Type::AngleUniforms, outStructs.angleUniforms,
1079 angleUniformsGlobalInstanceVar.getDriverUniformsVariable(), nullptr},
1080 {Pipeline::Type::UserUniforms, outStructs.userUniforms, nullptr, nullptr},
1081 {Pipeline::Type::VertexIn, outStructs.vertexIn, nullptr, &inputVaryings},
1082 {Pipeline::Type::VertexOut, outStructs.vertexOut, nullptr, &outputVaryings},
1083 {Pipeline::Type::FragmentIn, outStructs.fragmentIn, nullptr, &inputVaryings},
1084 {Pipeline::Type::FragmentOut, outStructs.fragmentOut, nullptr, &outputVaryings},
1085 {Pipeline::Type::InvocationVertexGlobals, outStructs.invocationVertexGlobals, nullptr,
1086 nullptr},
1087 {Pipeline::Type::InvocationFragmentGlobals, outStructs.invocationFragmentGlobals, nullptr,
1088 &inputVaryings},
1089 {Pipeline::Type::UniformBuffer, outStructs.uniformBuffers, nullptr, nullptr},
1090 };
1091
1092 for (Info &info : infos)
1093 {
1094 if ((compiler.getShaderType() != GL_VERTEX_SHADER &&
1095 (info.pipelineType == Pipeline::Type::VertexIn ||
1096 info.pipelineType == Pipeline::Type::VertexOut ||
1097 info.pipelineType == Pipeline::Type::InvocationVertexGlobals)) ||
1098 (compiler.getShaderType() != GL_FRAGMENT_SHADER &&
1099 (info.pipelineType == Pipeline::Type::FragmentIn ||
1100 info.pipelineType == Pipeline::Type::FragmentOut ||
1101 info.pipelineType == Pipeline::Type::InvocationFragmentGlobals)))
1102 continue;
1103
1104 Pipeline pipeline{info.pipelineType, info.globalInstanceVar};
1105 if (!RewritePipeline(compiler, root, idGen, pipeline, symbolEnv, info.variableInfo,
1106 info.outStruct))
1107 {
1108 return false;
1109 }
1110 }
1111
1112 return true;
1113 }
1114