xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/spirv/RewriteInterpolateAtOffset.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of InterpolateAtOffset viewport transformation.
7 // See header for more info.
8 
9 #include "compiler/translator/tree_ops/spirv/RewriteInterpolateAtOffset.h"
10 
11 #include "common/angleutils.h"
12 #include "compiler/translator/StaticType.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/spirv/TranslatorSPIRV.h"
15 #include "compiler/translator/tree_util/DriverUniform.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/SpecializationConstant.h"
19 
20 namespace sh
21 {
22 
23 namespace
24 {
25 
26 class Traverser : public TIntermTraverser
27 {
28   public:
29     Traverser(TSymbolTable *symbolTable, SpecConst *specConst, const DriverUniform *driverUniforms);
30 
31     bool update(TCompiler *compiler, TIntermBlock *root);
32 
33   private:
34     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
35 
36     const TFunction *getRotateFunc();
37 
38     SpecConst *mSpecConst                = nullptr;
39     const DriverUniform *mDriverUniforms = nullptr;
40 
41     TIntermFunctionDefinition *mRotateFunc = nullptr;
42 };
43 
Traverser(TSymbolTable * symbolTable,SpecConst * specConst,const DriverUniform * driverUniforms)44 Traverser::Traverser(TSymbolTable *symbolTable,
45                      SpecConst *specConst,
46                      const DriverUniform *driverUniforms)
47     : TIntermTraverser(true, false, false, symbolTable),
48       mSpecConst(specConst),
49       mDriverUniforms(driverUniforms)
50 {}
51 
update(TCompiler * compiler,TIntermBlock * root)52 bool Traverser::update(TCompiler *compiler, TIntermBlock *root)
53 {
54     if (mRotateFunc != nullptr)
55     {
56         const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
57         root->insertStatement(firstFunctionIndex, mRotateFunc);
58     }
59 
60     return updateTree(compiler, root);
61 }
62 
visitAggregate(Visit visit,TIntermAggregate * node)63 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
64 {
65     // Decide if the node represents the call of texelFetchOffset.
66     if (!BuiltInGroup::IsBuiltIn(node->getOp()))
67     {
68         return true;
69     }
70 
71     ASSERT(node->getFunction()->symbolType() == SymbolType::BuiltIn);
72     if (node->getFunction()->name() != "interpolateAtOffset")
73     {
74         return true;
75     }
76 
77     const TIntermSequence *sequence = node->getSequence();
78     ASSERT(sequence->size() == 2u);
79 
80     // offset
81     TIntermTyped *offsetNode = sequence->at(1)->getAsTyped();
82     ASSERT(offsetNode->getType().getBasicType() == EbtFloat &&
83            offsetNode->getType().getNominalSize() == 2);
84 
85     // Rotate the offset as necessary.
86     const TFunction *rotateFunc = getRotateFunc();
87 
88     TIntermSequence args = {
89         offsetNode,
90     };
91     TIntermTyped *correctedOffset = TIntermAggregate::CreateFunctionCall(*rotateFunc, &args);
92     correctedOffset->setLine(offsetNode->getLine());
93 
94     // Replace the offset by the rotated one.
95     queueReplacementWithParent(node, offsetNode, correctedOffset, OriginalNode::IS_DROPPED);
96 
97     return true;
98 }
99 
getRotateFunc()100 const TFunction *Traverser::getRotateFunc()
101 {
102     if (mRotateFunc != nullptr)
103     {
104         return mRotateFunc->getFunction();
105     }
106 
107     // The function prototype is vec2 ANGLERotateInterpolateOffset(vec2 offset)
108     const TType *vec2Type = StaticType::GetBasic<EbtFloat, EbpMedium, 2>();
109 
110     TType *offsetType = new TType(*vec2Type);
111     offsetType->setQualifier(EvqParamIn);
112 
113     TVariable *offsetParam = new TVariable(mSymbolTable, ImmutableString("offset"), offsetType,
114                                            SymbolType::AngleInternal);
115 
116     TFunction *function =
117         new TFunction(mSymbolTable, ImmutableString("ANGLERotateInterpolateOffset"),
118                       SymbolType::AngleInternal, vec2Type, true);
119     function->addParameter(offsetParam);
120 
121     // The function body is as such:
122     //
123     //     return (swap ? offset.yx : offset) * flip;
124 
125     TIntermTyped *swapXY = mSpecConst->getSwapXY();
126     if (swapXY == nullptr)
127     {
128         swapXY = mDriverUniforms->getSwapXY();
129     }
130 
131     TIntermTyped *flipXY = mDriverUniforms->getFlipXY(mSymbolTable, DriverUniformFlip::Fragment);
132 
133     TIntermSwizzle *offsetYX = new TIntermSwizzle(new TIntermSymbol(offsetParam), {1, 0});
134 
135     TIntermTyped *swapped = new TIntermTernary(swapXY, offsetYX, new TIntermSymbol(offsetParam));
136     TIntermTyped *flipped = new TIntermBinary(EOpMul, swapped, flipXY);
137     TIntermBranch *returnStatement = new TIntermBranch(EOpReturn, flipped);
138 
139     TIntermBlock *body = new TIntermBlock;
140     body->appendStatement(returnStatement);
141 
142     mRotateFunc = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
143     return function;
144 }
145 
146 }  // anonymous namespace
147 
RewriteInterpolateAtOffset(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int shaderVersion,SpecConst * specConst,const DriverUniform * driverUniforms)148 bool RewriteInterpolateAtOffset(TCompiler *compiler,
149                                 TIntermBlock *root,
150                                 TSymbolTable *symbolTable,
151                                 int shaderVersion,
152                                 SpecConst *specConst,
153                                 const DriverUniform *driverUniforms)
154 {
155     // interpolateAtOffset is only valid in GLSL 3.0 and later.
156     if (shaderVersion < 300)
157     {
158         return true;
159     }
160 
161     Traverser traverser(symbolTable, specConst, driverUniforms);
162     root->traverse(&traverser);
163     return traverser.update(compiler, root);
164 }
165 
166 }  // namespace sh
167