1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ClampGLLayer: Clamp gl_Layer to 0 if framebuffer is not layered.
7 //
8
9 #include "compiler/translator/tree_ops/spirv/ClampGLLayer.h"
10
11 #include "compiler/translator/StaticType.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/DriverUniform.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16
17 namespace sh
18 {
19 namespace
20 {
21 // A traverser to check if gl_Layer is used at all.
22 class HasGLLayerTraverser : public TIntermTraverser
23 {
24 public:
HasGLLayerTraverser(TSymbolTable * symbolTable)25 HasGLLayerTraverser(TSymbolTable *symbolTable)
26 : TIntermTraverser(true, false, false, symbolTable)
27 {}
28
referencesGLLayer() const29 bool referencesGLLayer() const { return mReferencesGLLayer; }
30
visitSymbol(TIntermSymbol * symbol)31 void visitSymbol(TIntermSymbol *symbol) override
32 {
33 if (symbol->getQualifier() == EvqLayerOut)
34 {
35 mReferencesGLLayer = true;
36 }
37 }
38
39 private:
40 bool mReferencesGLLayer = false;
41 };
42
43 // A traverser that adds `if (!layeredFramebuffer) gl_Layer = 0;` before emitVertex() in geometry
44 // shaders.
45 class ClampGLLayerTraverser : public TIntermTraverser
46 {
47 public:
ClampGLLayerTraverser(TSymbolTable * symbolTable,const DriverUniform * driverUniforms,int shaderVersion)48 ClampGLLayerTraverser(TSymbolTable *symbolTable,
49 const DriverUniform *driverUniforms,
50 int shaderVersion)
51 : TIntermTraverser(true, false, false, symbolTable),
52 mDriverUniforms(driverUniforms),
53 mShaderVersion(shaderVersion)
54 {}
55
56 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
57
58 private:
59 const DriverUniform *mDriverUniforms;
60 int mShaderVersion;
61 };
62
visitAggregate(Visit visit,TIntermAggregate * node)63 bool ClampGLLayerTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
64 {
65 ASSERT(visit == Visit::PreVisit);
66
67 if (node->getOp() != EOpEmitVertex)
68 {
69 return false;
70 }
71
72 // if (!layeredFramebuffer)
73 TIntermTyped *layeredFramebuffer =
74 new TIntermUnary(EOpLogicalNot, mDriverUniforms->getLayeredFramebuffer(), nullptr);
75
76 // gl_Layer = 0;
77 const TVariable *gl_Layer = static_cast<const TVariable *>(
78 mSymbolTable->findBuiltIn(ImmutableString("gl_Layer"), mShaderVersion));
79 TIntermBinary *setToZero =
80 new TIntermBinary(EOpAssign, new TIntermSymbol(gl_Layer), CreateIndexNode(0));
81
82 TIntermBlock *block = new TIntermBlock;
83 block->appendStatement(setToZero);
84
85 TIntermIfElse *ifNotLayered = new TIntermIfElse(layeredFramebuffer, block, nullptr);
86
87 TIntermSequence replacement;
88 replacement.push_back(ifNotLayered);
89 replacement.push_back(node);
90 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, std::move(replacement));
91
92 return false;
93 }
94 } // anonymous namespace
95
ClampGLLayer(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const DriverUniform * driverUniforms)96 bool ClampGLLayer(TCompiler *compiler,
97 TIntermBlock *root,
98 TSymbolTable *symbolTable,
99 const DriverUniform *driverUniforms)
100 {
101 // First, check if there is a reference to gl_Layer. If there isn't, there's nothing to do.
102 // Note that if gl_Layer isn't otherwise set, this transformation adds static usage of it
103 // without initializaing it in every path, leading to multiple drivers crashing / failing tests.
104 HasGLLayerTraverser hasGLLayer(symbolTable);
105 root->traverse(&hasGLLayer);
106 if (!hasGLLayer.referencesGLLayer())
107 {
108 return true;
109 }
110
111 ClampGLLayerTraverser traverser(symbolTable, driverUniforms, compiler->getShaderVersion());
112 root->traverse(&traverser);
113 return traverser.updateTree(compiler, root);
114 }
115 } // namespace sh
116