//
// Copyright 2020 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
6
7 #include "compiler/translator/tree_ops/msl/RewriteUnaddressableReferences.h"
8 #include "compiler/translator/AsNode.h"
9 #include "compiler/translator/IntermRebuild.h"
10 #include "compiler/translator/msl/AstHelpers.h"
11
12 using namespace sh;
13
14 namespace
15 {
16
IsOutParam(const TType & paramType)17 bool IsOutParam(const TType ¶mType)
18 {
19 const TQualifier qual = paramType.getQualifier();
20 switch (qual)
21 {
22 case TQualifier::EvqParamInOut:
23 case TQualifier::EvqParamOut:
24 return true;
25
26 default:
27 return false;
28 }
29 }
30
IsVectorAccess(TIntermBinary & binary)31 bool IsVectorAccess(TIntermBinary &binary)
32 {
33 TOperator op = binary.getOp();
34 switch (op)
35 {
36 case TOperator::EOpIndexDirect:
37 case TOperator::EOpIndexIndirect:
38 break;
39
40 default:
41 return false;
42 }
43
44 const TType &leftType = binary.getLeft()->getType();
45 if (!leftType.isVector() || leftType.isArray())
46 {
47 return false;
48 }
49
50 ASSERT(IsScalarBasicType(binary.getType()));
51
52 return true;
53 }
54
IsVectorAccess(TIntermNode & node)55 bool IsVectorAccess(TIntermNode &node)
56 {
57 if (auto *bin = node.getAsBinaryNode())
58 {
59 return IsVectorAccess(*bin);
60 }
61 return false;
62 }
63
64 // Differs from IsAssignment in that it does not include (++) or (--).
IsAssignEqualsSign(TOperator op)65 bool IsAssignEqualsSign(TOperator op)
66 {
67 switch (op)
68 {
69 case TOperator::EOpAssign:
70 case TOperator::EOpInitialize:
71 case TOperator::EOpAddAssign:
72 case TOperator::EOpSubAssign:
73 case TOperator::EOpMulAssign:
74 case TOperator::EOpVectorTimesMatrixAssign:
75 case TOperator::EOpVectorTimesScalarAssign:
76 case TOperator::EOpMatrixTimesScalarAssign:
77 case TOperator::EOpMatrixTimesMatrixAssign:
78 case TOperator::EOpDivAssign:
79 case TOperator::EOpIModAssign:
80 case TOperator::EOpBitShiftLeftAssign:
81 case TOperator::EOpBitShiftRightAssign:
82 case TOperator::EOpBitwiseAndAssign:
83 case TOperator::EOpBitwiseXorAssign:
84 case TOperator::EOpBitwiseOrAssign:
85 return true;
86
87 default:
88 return false;
89 }
90 }
91
92 // Only includes ($=) style assigns, where ($) is a binary op.
IsCompoundAssign(TOperator op)93 bool IsCompoundAssign(TOperator op)
94 {
95 switch (op)
96 {
97 case TOperator::EOpAddAssign:
98 case TOperator::EOpSubAssign:
99 case TOperator::EOpMulAssign:
100 case TOperator::EOpVectorTimesMatrixAssign:
101 case TOperator::EOpVectorTimesScalarAssign:
102 case TOperator::EOpMatrixTimesScalarAssign:
103 case TOperator::EOpMatrixTimesMatrixAssign:
104 case TOperator::EOpDivAssign:
105 case TOperator::EOpIModAssign:
106 case TOperator::EOpBitShiftLeftAssign:
107 case TOperator::EOpBitShiftRightAssign:
108 case TOperator::EOpBitwiseAndAssign:
109 case TOperator::EOpBitwiseXorAssign:
110 case TOperator::EOpBitwiseOrAssign:
111 return true;
112
113 default:
114 return false;
115 }
116 }
117
ReturnsReference(TOperator op)118 bool ReturnsReference(TOperator op)
119 {
120 switch (op)
121 {
122 case TOperator::EOpAssign:
123 case TOperator::EOpInitialize:
124 case TOperator::EOpAddAssign:
125 case TOperator::EOpSubAssign:
126 case TOperator::EOpMulAssign:
127 case TOperator::EOpVectorTimesMatrixAssign:
128 case TOperator::EOpVectorTimesScalarAssign:
129 case TOperator::EOpMatrixTimesScalarAssign:
130 case TOperator::EOpMatrixTimesMatrixAssign:
131 case TOperator::EOpDivAssign:
132 case TOperator::EOpIModAssign:
133 case TOperator::EOpBitShiftLeftAssign:
134 case TOperator::EOpBitShiftRightAssign:
135 case TOperator::EOpBitwiseAndAssign:
136 case TOperator::EOpBitwiseXorAssign:
137 case TOperator::EOpBitwiseOrAssign:
138
139 case TOperator::EOpPostIncrement:
140 case TOperator::EOpPostDecrement:
141 case TOperator::EOpPreIncrement:
142 case TOperator::EOpPreDecrement:
143
144 case TOperator::EOpIndexDirect:
145 case TOperator::EOpIndexIndirect:
146 case TOperator::EOpIndexDirectStruct:
147 case TOperator::EOpIndexDirectInterfaceBlock:
148
149 return true;
150
151 default:
152 return false;
153 }
154 }
155
DecomposeCompoundAssignment(TIntermBinary & node)156 TIntermTyped &DecomposeCompoundAssignment(TIntermBinary &node)
157 {
158 TOperator op = node.getOp();
159 switch (op)
160 {
161 case TOperator::EOpAddAssign:
162 op = TOperator::EOpAdd;
163 break;
164 case TOperator::EOpSubAssign:
165 op = TOperator::EOpSub;
166 break;
167 case TOperator::EOpMulAssign:
168 op = TOperator::EOpMul;
169 break;
170 case TOperator::EOpVectorTimesMatrixAssign:
171 op = TOperator::EOpVectorTimesMatrix;
172 break;
173 case TOperator::EOpVectorTimesScalarAssign:
174 op = TOperator::EOpVectorTimesScalar;
175 break;
176 case TOperator::EOpMatrixTimesScalarAssign:
177 op = TOperator::EOpMatrixTimesScalar;
178 break;
179 case TOperator::EOpMatrixTimesMatrixAssign:
180 op = TOperator::EOpMatrixTimesMatrix;
181 break;
182 case TOperator::EOpDivAssign:
183 op = TOperator::EOpDiv;
184 break;
185 case TOperator::EOpIModAssign:
186 op = TOperator::EOpIMod;
187 break;
188 case TOperator::EOpBitShiftLeftAssign:
189 op = TOperator::EOpBitShiftLeft;
190 break;
191 case TOperator::EOpBitShiftRightAssign:
192 op = TOperator::EOpBitShiftRight;
193 break;
194 case TOperator::EOpBitwiseAndAssign:
195 op = TOperator::EOpBitwiseAnd;
196 break;
197 case TOperator::EOpBitwiseXorAssign:
198 op = TOperator::EOpBitwiseXor;
199 break;
200 case TOperator::EOpBitwiseOrAssign:
201 op = TOperator::EOpBitwiseOr;
202 break;
203 default:
204 UNREACHABLE();
205 }
206
207 // This assumes SeparateCompoundExpressions has already been called.
208 // This assumption allows this code to not need to introduce temporaries.
209 //
210 // e.g. dont have to worry about:
211 // vec[hasSideEffect()] *= 4
212 // becoming
213 // vec[hasSideEffect()] = vec[hasSideEffect()] * 4
214
215 TIntermTyped *left = node.getLeft();
216 TIntermTyped *right = node.getRight();
217 return *new TIntermBinary(TOperator::EOpAssign, left->deepCopy(),
218 new TIntermBinary(op, left, right));
219 }
220
221 class Rewriter1 : public TIntermRebuild
222 {
223 public:
Rewriter1(TCompiler & compiler)224 Rewriter1(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
225
visitBinaryPost(TIntermBinary & binaryNode)226 PostResult visitBinaryPost(TIntermBinary &binaryNode) override
227 {
228 const TOperator op = binaryNode.getOp();
229 if (IsCompoundAssign(op))
230 {
231 TIntermTyped &left = *binaryNode.getLeft();
232 if (left.getAsSwizzleNode() || IsVectorAccess(left))
233 {
234 return DecomposeCompoundAssignment(binaryNode);
235 }
236 }
237 return binaryNode;
238 }
239 };
240
241 class Rewriter2 : public TIntermRebuild
242 {
243 std::vector<bool> mRequiresAddressingStack;
244 SymbolEnv &mSymbolEnv;
245
246 private:
requiresAddressing() const247 bool requiresAddressing() const
248 {
249 if (mRequiresAddressingStack.empty())
250 {
251 return false;
252 }
253 return mRequiresAddressingStack.back();
254 }
255
256 public:
~Rewriter2()257 ~Rewriter2() override { ASSERT(mRequiresAddressingStack.empty()); }
258
Rewriter2(TCompiler & compiler,SymbolEnv & symbolEnv)259 Rewriter2(TCompiler &compiler, SymbolEnv &symbolEnv)
260 : TIntermRebuild(compiler, true, true), mSymbolEnv(symbolEnv)
261 {}
262
visitAggregatePre(TIntermAggregate & aggregateNode)263 PreResult visitAggregatePre(TIntermAggregate &aggregateNode) override
264 {
265 const TFunction *func = aggregateNode.getFunction();
266 if (!func)
267 {
268 return aggregateNode;
269 }
270
271 TIntermSequence &args = *aggregateNode.getSequence();
272 size_t argCount = args.size();
273
274 for (size_t i = 0; i < argCount; ++i)
275 {
276 const TVariable ¶m = *func->getParam(i);
277 const TType ¶mType = param.getType();
278 TIntermNode *arg = args[i];
279 ASSERT(arg);
280
281 mRequiresAddressingStack.push_back(IsOutParam(paramType));
282 args[i] = rebuild(*arg).single();
283 ASSERT(args[i]);
284 ASSERT(!mRequiresAddressingStack.empty());
285 mRequiresAddressingStack.pop_back();
286 }
287
288 return {aggregateNode, VisitBits::Neither};
289 }
290
visitSwizzlePost(TIntermSwizzle & swizzleNode)291 PostResult visitSwizzlePost(TIntermSwizzle &swizzleNode) override
292 {
293 if (!requiresAddressing())
294 {
295 return swizzleNode;
296 }
297
298 TIntermTyped &vecNode = *swizzleNode.getOperand();
299 const TQualifierList &offsets = swizzleNode.getSwizzleOffsets();
300 ASSERT(!offsets.empty());
301 ASSERT(offsets.size() <= 4);
302
303 auto &args = *new TIntermSequence();
304 args.reserve(offsets.size() + 1);
305 args.push_back(&vecNode);
306 for (int offset : offsets)
307 {
308 args.push_back(new TIntermConstantUnion(new TConstantUnion(offset),
309 *new TType(TBasicType::EbtInt)));
310 }
311
312 return mSymbolEnv.callFunctionOverload(Name("swizzle_ref"), swizzleNode.getType(), args);
313 }
314
visitBinaryPre(TIntermBinary & binaryNode)315 PreResult visitBinaryPre(TIntermBinary &binaryNode) override
316 {
317 const TOperator op = binaryNode.getOp();
318
319 const bool isAccess = IsVectorAccess(binaryNode);
320
321 const bool disableTop = !ReturnsReference(op) || !requiresAddressing();
322 const bool disableLeft = disableTop;
323 const bool disableRight = disableTop || isAccess || IsAssignEqualsSign(op);
324
325 auto traverse = [&](TIntermTyped &node, const bool disable) -> TIntermTyped & {
326 if (disable)
327 {
328 mRequiresAddressingStack.push_back(false);
329 }
330 auto *newNode = asNode<TIntermTyped>(rebuild(node).single());
331 ASSERT(newNode);
332 if (disable)
333 {
334 mRequiresAddressingStack.pop_back();
335 }
336 return *newNode;
337 };
338
339 TIntermTyped &leftNode = *binaryNode.getLeft();
340 TIntermTyped &rightNode = *binaryNode.getRight();
341
342 TIntermTyped &newLeft = traverse(leftNode, disableLeft);
343 TIntermTyped &newRight = traverse(rightNode, disableRight);
344
345 if (!isAccess || disableTop)
346 {
347 if (&leftNode == &newLeft && &rightNode == &newRight)
348 {
349 return {&binaryNode, VisitBits::Neither};
350 }
351 return {*new TIntermBinary(op, &newLeft, &newRight), VisitBits::Neither};
352 }
353
354 return {mSymbolEnv.callFunctionOverload(Name("elem_ref"), binaryNode.getType(),
355 *new TIntermSequence{&newLeft, &newRight}),
356 VisitBits::Neither};
357 }
358 };
359
360 } // anonymous namespace
361
RewriteUnaddressableReferences(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv)362 bool sh::RewriteUnaddressableReferences(TCompiler &compiler,
363 TIntermBlock &root,
364 SymbolEnv &symbolEnv)
365 {
366 if (!Rewriter1(compiler).rebuildRoot(root))
367 {
368 return false;
369 }
370 if (!Rewriter2(compiler, symbolEnv).rebuildRoot(root))
371 {
372 return false;
373 }
374 return true;
375 }
376