xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/spirv/ClampGLLayer.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ClampGLLayer: Clamp gl_Layer to 0 if framebuffer is not layered.
7 //
8 
9 #include "compiler/translator/tree_ops/spirv/ClampGLLayer.h"
10 
11 #include "compiler/translator/StaticType.h"
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/DriverUniform.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 
17 namespace sh
18 {
19 namespace
20 {
21 // A traverser to check if gl_Layer is used at all.
22 class HasGLLayerTraverser : public TIntermTraverser
23 {
24   public:
HasGLLayerTraverser(TSymbolTable * symbolTable)25     HasGLLayerTraverser(TSymbolTable *symbolTable)
26         : TIntermTraverser(true, false, false, symbolTable)
27     {}
28 
referencesGLLayer() const29     bool referencesGLLayer() const { return mReferencesGLLayer; }
30 
visitSymbol(TIntermSymbol * symbol)31     void visitSymbol(TIntermSymbol *symbol) override
32     {
33         if (symbol->getQualifier() == EvqLayerOut)
34         {
35             mReferencesGLLayer = true;
36         }
37     }
38 
39   private:
40     bool mReferencesGLLayer = false;
41 };
42 
43 // A traverser that adds `if (!layeredFramebuffer) gl_Layer = 0;` before emitVertex() in geometry
44 // shaders.
45 class ClampGLLayerTraverser : public TIntermTraverser
46 {
47   public:
ClampGLLayerTraverser(TSymbolTable * symbolTable,const DriverUniform * driverUniforms,int shaderVersion)48     ClampGLLayerTraverser(TSymbolTable *symbolTable,
49                           const DriverUniform *driverUniforms,
50                           int shaderVersion)
51         : TIntermTraverser(true, false, false, symbolTable),
52           mDriverUniforms(driverUniforms),
53           mShaderVersion(shaderVersion)
54     {}
55 
56     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
57 
58   private:
59     const DriverUniform *mDriverUniforms;
60     int mShaderVersion;
61 };
62 
visitAggregate(Visit visit,TIntermAggregate * node)63 bool ClampGLLayerTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
64 {
65     ASSERT(visit == Visit::PreVisit);
66 
67     if (node->getOp() != EOpEmitVertex)
68     {
69         return false;
70     }
71 
72     // if (!layeredFramebuffer)
73     TIntermTyped *layeredFramebuffer =
74         new TIntermUnary(EOpLogicalNot, mDriverUniforms->getLayeredFramebuffer(), nullptr);
75 
76     // gl_Layer = 0;
77     const TVariable *gl_Layer = static_cast<const TVariable *>(
78         mSymbolTable->findBuiltIn(ImmutableString("gl_Layer"), mShaderVersion));
79     TIntermBinary *setToZero =
80         new TIntermBinary(EOpAssign, new TIntermSymbol(gl_Layer), CreateIndexNode(0));
81 
82     TIntermBlock *block = new TIntermBlock;
83     block->appendStatement(setToZero);
84 
85     TIntermIfElse *ifNotLayered = new TIntermIfElse(layeredFramebuffer, block, nullptr);
86 
87     TIntermSequence replacement;
88     replacement.push_back(ifNotLayered);
89     replacement.push_back(node);
90     mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node, std::move(replacement));
91 
92     return false;
93 }
94 }  // anonymous namespace
95 
ClampGLLayer(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const DriverUniform * driverUniforms)96 bool ClampGLLayer(TCompiler *compiler,
97                   TIntermBlock *root,
98                   TSymbolTable *symbolTable,
99                   const DriverUniform *driverUniforms)
100 {
101     // First, check if there is a reference to gl_Layer.  If there isn't, there's nothing to do.
102     // Note that if gl_Layer isn't otherwise set, this transformation adds static usage of it
103     // without initializaing it in every path, leading to multiple drivers crashing / failing tests.
104     HasGLLayerTraverser hasGLLayer(symbolTable);
105     root->traverse(&hasGLLayer);
106     if (!hasGLLayer.referencesGLLayer())
107     {
108         return true;
109     }
110 
111     ClampGLLayerTraverser traverser(symbolTable, driverUniforms, compiler->getShaderVersion());
112     root->traverse(&traverser);
113     return traverser.updateTree(compiler, root);
114 }
115 }  // namespace sh
116