1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/msl/ConvertUnsupportedConstructorsToFunctionCalls.h"
8
9 #include "compiler/translator/ImmutableString.h"
10 #include "compiler/translator/IntermRebuild.h"
11 #include "compiler/translator/Symbol.h"
12 #include "compiler/translator/tree_util/FindFunction.h"
13 #include "compiler/translator/tree_util/IntermNode_util.h"
14
15 using namespace sh;
16
17 namespace
18 {
19
AppendMatrixElementArgument(TIntermSymbol * parameter,int colIndex,int rowIndex,TIntermSequence * returnCtorArgs)20 void AppendMatrixElementArgument(TIntermSymbol *parameter,
21 int colIndex,
22 int rowIndex,
23 TIntermSequence *returnCtorArgs)
24 {
25 TIntermBinary *matColN =
26 new TIntermBinary(EOpIndexDirect, parameter->deepCopy(), CreateIndexNode(colIndex));
27 TIntermSwizzle *matElem = new TIntermSwizzle(matColN, {rowIndex});
28 returnCtorArgs->push_back(matElem);
29 }
30
31 // Adds the argument to sequence for a scalar constructor.
32 // Given scalar(scalarA) appends scalarA
33 // Given scalar(vecA) appends vecA.x
34 // Given scalar(matA) appends matA[0].x
AppendScalarFromNonScalarArguments(TFunction & function,TIntermSequence * returnCtorArgs)35 void AppendScalarFromNonScalarArguments(TFunction &function, TIntermSequence *returnCtorArgs)
36 {
37 const TVariable *var = function.getParam(0);
38 TIntermSymbol *arg0 = new TIntermSymbol(var);
39
40 const TType &type = arg0->getType();
41
42 if (type.isScalar())
43 {
44 returnCtorArgs->push_back(arg0);
45 }
46 else if (type.isVector())
47 {
48 TIntermSwizzle *vecX = new TIntermSwizzle(arg0, {0});
49 returnCtorArgs->push_back(vecX);
50 }
51 else if (type.isMatrix())
52 {
53 AppendMatrixElementArgument(arg0, 0, 0, returnCtorArgs);
54 }
55 }
56
57 // Adds the arguments to sequence for a vector constructor from a scalar.
58 // Given vecN(scalarA) appends scalarA, scalarA, ... n times
AppendVectorFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)59 void AppendVectorFromScalarArgument(const TType &type,
60 TFunction &function,
61 TIntermSequence *returnCtorArgs)
62 {
63 const uint8_t vectorSize = type.getNominalSize();
64 const TVariable *var = function.getParam(0);
65 TIntermSymbol *v = new TIntermSymbol(var);
66 for (uint8_t i = 0; i < vectorSize; ++i)
67 {
68 returnCtorArgs->push_back(v->deepCopy());
69 }
70 }
71
72 // Adds the arguments to sequence for a vector or matrix constructor from the available arguments
73 // applying arguments in order until the requested number of values have been extracted from the
74 // given arguments or until there are no more arguments.
AppendValuesFromMultipleArguments(int numValuesNeeded,TFunction & function,TIntermSequence * returnCtorArgs)75 void AppendValuesFromMultipleArguments(int numValuesNeeded,
76 TFunction &function,
77 TIntermSequence *returnCtorArgs)
78 {
79 size_t numParameters = function.getParamCount();
80 size_t paramIndex = 0;
81 uint8_t colIndex = 0;
82 uint8_t rowIndex = 0;
83
84 for (int i = 0; i < numValuesNeeded && paramIndex < numParameters; ++i)
85 {
86 const TVariable *p = function.getParam(paramIndex);
87 TIntermSymbol *parameter = new TIntermSymbol(p);
88 if (parameter->isScalar())
89 {
90 returnCtorArgs->push_back(parameter);
91 ++paramIndex;
92 }
93 else if (parameter->isVector())
94 {
95 TIntermSwizzle *vecS = new TIntermSwizzle(parameter->deepCopy(), {rowIndex++});
96 returnCtorArgs->push_back(vecS);
97 if (rowIndex == parameter->getNominalSize())
98 {
99 ++paramIndex;
100 rowIndex = 0;
101 }
102 }
103 else if (parameter->isMatrix())
104 {
105 AppendMatrixElementArgument(parameter, colIndex, rowIndex++, returnCtorArgs);
106 if (rowIndex == parameter->getSecondarySize())
107 {
108 rowIndex = 0;
109 ++colIndex;
110 if (colIndex == parameter->getNominalSize())
111 {
112 colIndex = 0;
113 ++paramIndex;
114 }
115 }
116 }
117 }
118 }
119
120 // Adds the arguments for a matrix constructor from a scalar
121 // putting the scalar along the diagonal and 0 everywhere else.
AppendMatrixFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)122 void AppendMatrixFromScalarArgument(const TType &type,
123 TFunction &function,
124 TIntermSequence *returnCtorArgs)
125 {
126 const TVariable *var = function.getParam(0);
127 TIntermSymbol *v = new TIntermSymbol(var);
128 const uint8_t numCols = type.getNominalSize();
129 const uint8_t numRows = type.getSecondarySize();
130 for (uint8_t col = 0; col < numCols; ++col)
131 {
132 for (uint8_t row = 0; row < numRows; ++row)
133 {
134 if (col == row)
135 {
136 returnCtorArgs->push_back(v->deepCopy());
137 }
138 else
139 {
140 returnCtorArgs->push_back(CreateFloatNode(0.0f, sh::EbpUndefined));
141 }
142 }
143 }
144 }
145
146 // Add the argument for a matrix constructor from a matrix
147 // copying elements from the same column/row and otherwise
148 // initialize to the identity matrix.
AppendMatrixFromMatrixArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)149 void AppendMatrixFromMatrixArgument(const TType &type,
150 TFunction &function,
151 TIntermSequence *returnCtorArgs)
152 {
153 const TVariable *var = function.getParam(0);
154 TIntermSymbol *v = new TIntermSymbol(var);
155 const uint8_t dstCols = type.getNominalSize();
156 const uint8_t dstRows = type.getSecondarySize();
157 const uint8_t srcCols = v->getNominalSize();
158 const uint8_t srcRows = v->getSecondarySize();
159 for (uint8_t dstCol = 0; dstCol < dstCols; ++dstCol)
160 {
161 for (uint8_t dstRow = 0; dstRow < dstRows; ++dstRow)
162 {
163 if (dstRow < srcRows && dstCol < srcCols)
164 {
165 AppendMatrixElementArgument(v, dstCol, dstRow, returnCtorArgs);
166 }
167 else
168 {
169 returnCtorArgs->push_back(
170 CreateFloatNode(dstRow == dstCol ? 1.0f : 0.0f, sh::EbpUndefined));
171 }
172 }
173 }
174 }
175
176 class Rebuild : public TIntermRebuild
177 {
178 public:
Rebuild(TCompiler & compiler)179 explicit Rebuild(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
visitAggregatePost(TIntermAggregate & node)180 PostResult visitAggregatePost(TIntermAggregate &node) override
181 {
182 if (!node.isConstructor())
183 {
184 return node;
185 }
186
187 TIntermSequence &arguments = *node.getSequence();
188 if (arguments.empty())
189 {
190 return node;
191 }
192
193 const TType &type = node.getType();
194 const TType &arg0Type = arguments[0]->getAsTyped()->getType();
195
196 if (!type.isScalar() && !type.isVector() && !type.isMatrix())
197 {
198 return node;
199 }
200
201 if (type.isArray())
202 {
203 return node;
204 }
205
206 // check for type_ctor(sameType)
207 // scalar(scalar) -> passthrough
208 // vecN(vecN) -> passthrough
209 // matN(matN) -> passthrough
210 if (arguments.size() == 1 && arg0Type == type)
211 {
212 return node;
213 }
214
215 // The following are simple casts:
216 //
217 // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
218 // - gvecN(vN) (where the argument is a single vector with the same number of components).
219 // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions). Note
220 // that
221 // matrices are always float, so there's no actual cast and this would be a no-op.
222 //
223 const bool isSingleScalarCast =
224 arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
225 const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
226 arg0Type.isVector() &&
227 type.getNominalSize() == arg0Type.getNominalSize();
228 const bool isSingleMatrixCast =
229 arguments.size() == 1 && type.isMatrix() && arg0Type.isMatrix() &&
230 type.getCols() == arg0Type.getCols() && type.getRows() == arg0Type.getRows();
231 if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
232 {
233 return node;
234 }
235
236 // Cases we need to handle:
237 // scalar(vec)
238 // scalar(mat)
239 // vecN(scalar)
240 // vecN(vecM)
241 // vecN(a,...)
242 // matN(scalar) -> diag
243 // matN(vec) -> fail!
244 // manN(matM) -> corner + ident
245 // matN(a, ...)
246
247 // Build a function and pass all the constructor's arguments to it.
248 TIntermBlock *body = new TIntermBlock;
249 TFunction *function = new TFunction(&mSymbolTable, ImmutableString(""),
250 SymbolType::AngleInternal, &type, true);
251
252 for (size_t i = 0; i < arguments.size(); ++i)
253 {
254 TIntermTyped &arg = *arguments[i]->getAsTyped();
255 TType *argType = new TType(arg.getBasicType(), arg.getPrecision(), EvqParamIn,
256 arg.getNominalSize(), arg.getSecondarySize());
257 TVariable *var = CreateTempVariable(&mSymbolTable, argType);
258 function->addParameter(var);
259 }
260
261 // Build a return statement for the function that
262 // converts the arguments into the required type.
263 TIntermSequence *returnCtorArgs = new TIntermSequence();
264
265 if (type.isScalar())
266 {
267 AppendScalarFromNonScalarArguments(*function, returnCtorArgs);
268 }
269 else if (type.isVector())
270 {
271 if (arguments.size() == 1 && arg0Type.isScalar())
272 {
273 AppendVectorFromScalarArgument(type, *function, returnCtorArgs);
274 }
275 else
276 {
277 AppendValuesFromMultipleArguments(type.getNominalSize(), *function, returnCtorArgs);
278 }
279 }
280 else if (type.isMatrix())
281 {
282 if (arguments.size() == 1 && arg0Type.isScalar())
283 {
284 // MSL already handles this case
285 AppendMatrixFromScalarArgument(type, *function, returnCtorArgs);
286 }
287 else if (arg0Type.isMatrix())
288 {
289 AppendMatrixFromMatrixArgument(type, *function, returnCtorArgs);
290 }
291 else
292 {
293 AppendValuesFromMultipleArguments(type.getNominalSize() * type.getSecondarySize(),
294 *function, returnCtorArgs);
295 }
296 }
297
298 TIntermBranch *returnStatement =
299 new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, returnCtorArgs));
300 body->appendStatement(returnStatement);
301
302 TIntermFunctionDefinition *functionDefinition =
303 CreateInternalFunctionDefinitionNode(*function, body);
304 mFunctionDefs.push_back(functionDefinition);
305
306 TIntermTyped *functionCall = TIntermAggregate::CreateFunctionCall(*function, &arguments);
307
308 return *functionCall;
309 }
310
rewrite(TIntermBlock & root)311 bool rewrite(TIntermBlock &root)
312 {
313 if (!rebuildInPlace(root))
314 {
315 return true;
316 }
317
318 size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
319 for (TIntermFunctionDefinition *functionDefinition : mFunctionDefs)
320 {
321 root.insertChildNodes(firstFunctionIndex, TIntermSequence({functionDefinition}));
322 }
323
324 return mCompiler.validateAST(&root);
325 }
326
327 private:
328 TVector<TIntermFunctionDefinition *> mFunctionDefs;
329 };
330
331 } // anonymous namespace
332
ConvertUnsupportedConstructorsToFunctionCalls(TCompiler & compiler,TIntermBlock & root)333 bool sh::ConvertUnsupportedConstructorsToFunctionCalls(TCompiler &compiler, TIntermBlock &root)
334 {
335 return Rebuild(compiler).rewrite(root);
336 }
337