1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Implementation of InterpolateAtOffset viewport transformation.
7 // See header for more info.
8
9 #include "compiler/translator/tree_ops/spirv/RewriteInterpolateAtOffset.h"
10
11 #include "common/angleutils.h"
12 #include "compiler/translator/StaticType.h"
13 #include "compiler/translator/SymbolTable.h"
14 #include "compiler/translator/spirv/TranslatorSPIRV.h"
15 #include "compiler/translator/tree_util/DriverUniform.h"
16 #include "compiler/translator/tree_util/IntermNode_util.h"
17 #include "compiler/translator/tree_util/IntermTraverse.h"
18 #include "compiler/translator/tree_util/SpecializationConstant.h"
19
20 namespace sh
21 {
22
23 namespace
24 {
25
26 class Traverser : public TIntermTraverser
27 {
28 public:
29 Traverser(TSymbolTable *symbolTable, SpecConst *specConst, const DriverUniform *driverUniforms);
30
31 bool update(TCompiler *compiler, TIntermBlock *root);
32
33 private:
34 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
35
36 const TFunction *getRotateFunc();
37
38 SpecConst *mSpecConst = nullptr;
39 const DriverUniform *mDriverUniforms = nullptr;
40
41 TIntermFunctionDefinition *mRotateFunc = nullptr;
42 };
43
Traverser(TSymbolTable * symbolTable,SpecConst * specConst,const DriverUniform * driverUniforms)44 Traverser::Traverser(TSymbolTable *symbolTable,
45 SpecConst *specConst,
46 const DriverUniform *driverUniforms)
47 : TIntermTraverser(true, false, false, symbolTable),
48 mSpecConst(specConst),
49 mDriverUniforms(driverUniforms)
50 {}
51
update(TCompiler * compiler,TIntermBlock * root)52 bool Traverser::update(TCompiler *compiler, TIntermBlock *root)
53 {
54 if (mRotateFunc != nullptr)
55 {
56 const size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(root);
57 root->insertStatement(firstFunctionIndex, mRotateFunc);
58 }
59
60 return updateTree(compiler, root);
61 }
62
visitAggregate(Visit visit,TIntermAggregate * node)63 bool Traverser::visitAggregate(Visit visit, TIntermAggregate *node)
64 {
65 // Decide if the node represents the call of texelFetchOffset.
66 if (!BuiltInGroup::IsBuiltIn(node->getOp()))
67 {
68 return true;
69 }
70
71 ASSERT(node->getFunction()->symbolType() == SymbolType::BuiltIn);
72 if (node->getFunction()->name() != "interpolateAtOffset")
73 {
74 return true;
75 }
76
77 const TIntermSequence *sequence = node->getSequence();
78 ASSERT(sequence->size() == 2u);
79
80 // offset
81 TIntermTyped *offsetNode = sequence->at(1)->getAsTyped();
82 ASSERT(offsetNode->getType().getBasicType() == EbtFloat &&
83 offsetNode->getType().getNominalSize() == 2);
84
85 // Rotate the offset as necessary.
86 const TFunction *rotateFunc = getRotateFunc();
87
88 TIntermSequence args = {
89 offsetNode,
90 };
91 TIntermTyped *correctedOffset = TIntermAggregate::CreateFunctionCall(*rotateFunc, &args);
92 correctedOffset->setLine(offsetNode->getLine());
93
94 // Replace the offset by the rotated one.
95 queueReplacementWithParent(node, offsetNode, correctedOffset, OriginalNode::IS_DROPPED);
96
97 return true;
98 }
99
getRotateFunc()100 const TFunction *Traverser::getRotateFunc()
101 {
102 if (mRotateFunc != nullptr)
103 {
104 return mRotateFunc->getFunction();
105 }
106
107 // The function prototype is vec2 ANGLERotateInterpolateOffset(vec2 offset)
108 const TType *vec2Type = StaticType::GetBasic<EbtFloat, EbpMedium, 2>();
109
110 TType *offsetType = new TType(*vec2Type);
111 offsetType->setQualifier(EvqParamIn);
112
113 TVariable *offsetParam = new TVariable(mSymbolTable, ImmutableString("offset"), offsetType,
114 SymbolType::AngleInternal);
115
116 TFunction *function =
117 new TFunction(mSymbolTable, ImmutableString("ANGLERotateInterpolateOffset"),
118 SymbolType::AngleInternal, vec2Type, true);
119 function->addParameter(offsetParam);
120
121 // The function body is as such:
122 //
123 // return (swap ? offset.yx : offset) * flip;
124
125 TIntermTyped *swapXY = mSpecConst->getSwapXY();
126 if (swapXY == nullptr)
127 {
128 swapXY = mDriverUniforms->getSwapXY();
129 }
130
131 TIntermTyped *flipXY = mDriverUniforms->getFlipXY(mSymbolTable, DriverUniformFlip::Fragment);
132
133 TIntermSwizzle *offsetYX = new TIntermSwizzle(new TIntermSymbol(offsetParam), {1, 0});
134
135 TIntermTyped *swapped = new TIntermTernary(swapXY, offsetYX, new TIntermSymbol(offsetParam));
136 TIntermTyped *flipped = new TIntermBinary(EOpMul, swapped, flipXY);
137 TIntermBranch *returnStatement = new TIntermBranch(EOpReturn, flipped);
138
139 TIntermBlock *body = new TIntermBlock;
140 body->appendStatement(returnStatement);
141
142 mRotateFunc = new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
143 return function;
144 }
145
146 } // anonymous namespace
147
RewriteInterpolateAtOffset(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int shaderVersion,SpecConst * specConst,const DriverUniform * driverUniforms)148 bool RewriteInterpolateAtOffset(TCompiler *compiler,
149 TIntermBlock *root,
150 TSymbolTable *symbolTable,
151 int shaderVersion,
152 SpecConst *specConst,
153 const DriverUniform *driverUniforms)
154 {
155 // interpolateAtOffset is only valid in GLSL 3.0 and later.
156 if (shaderVersion < 300)
157 {
158 return true;
159 }
160
161 Traverser traverser(symbolTable, specConst, driverUniforms);
162 root->traverse(&traverser);
163 return traverser.update(compiler, root);
164 }
165
166 } // namespace sh
167