1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/ConvertUnsupportedConstructorsToFunctionCalls.h"
8 
9 #include "compiler/translator/ImmutableString.h"
10 #include "compiler/translator/IntermRebuild.h"
11 #include "compiler/translator/Symbol.h"
12 #include "compiler/translator/tree_util/FindFunction.h"
13 #include "compiler/translator/tree_util/IntermNode_util.h"
14 
15 using namespace sh;
16 
17 namespace
18 {
19 
AppendMatrixElementArgument(TIntermSymbol * parameter,int colIndex,int rowIndex,TIntermSequence * returnCtorArgs)20 void AppendMatrixElementArgument(TIntermSymbol *parameter,
21                                  int colIndex,
22                                  int rowIndex,
23                                  TIntermSequence *returnCtorArgs)
24 {
25     TIntermBinary *matColN =
26         new TIntermBinary(EOpIndexDirect, parameter->deepCopy(), CreateIndexNode(colIndex));
27     TIntermSwizzle *matElem = new TIntermSwizzle(matColN, {rowIndex});
28     returnCtorArgs->push_back(matElem);
29 }
30 
31 // Adds the argument to sequence for a scalar constructor.
32 // Given scalar(scalarA) appends scalarA
33 // Given scalar(vecA) appends vecA.x
34 // Given scalar(matA) appends matA[0].x
AppendScalarFromNonScalarArguments(TFunction & function,TIntermSequence * returnCtorArgs)35 void AppendScalarFromNonScalarArguments(TFunction &function, TIntermSequence *returnCtorArgs)
36 {
37     const TVariable *var = function.getParam(0);
38     TIntermSymbol *arg0  = new TIntermSymbol(var);
39 
40     const TType &type = arg0->getType();
41 
42     if (type.isScalar())
43     {
44         returnCtorArgs->push_back(arg0);
45     }
46     else if (type.isVector())
47     {
48         TIntermSwizzle *vecX = new TIntermSwizzle(arg0, {0});
49         returnCtorArgs->push_back(vecX);
50     }
51     else if (type.isMatrix())
52     {
53         AppendMatrixElementArgument(arg0, 0, 0, returnCtorArgs);
54     }
55 }
56 
57 // Adds the arguments to sequence for a vector constructor from a scalar.
58 // Given vecN(scalarA) appends scalarA, scalarA, ... n times
AppendVectorFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)59 void AppendVectorFromScalarArgument(const TType &type,
60                                     TFunction &function,
61                                     TIntermSequence *returnCtorArgs)
62 {
63     const uint8_t vectorSize = type.getNominalSize();
64     const TVariable *var     = function.getParam(0);
65     TIntermSymbol *v         = new TIntermSymbol(var);
66     for (uint8_t i = 0; i < vectorSize; ++i)
67     {
68         returnCtorArgs->push_back(v->deepCopy());
69     }
70 }
71 
72 // Adds the arguments to sequence for a vector or matrix constructor from the available arguments
73 // applying arguments in order until the requested number of values have been extracted from the
74 // given arguments or until there are no more arguments.
AppendValuesFromMultipleArguments(int numValuesNeeded,TFunction & function,TIntermSequence * returnCtorArgs)75 void AppendValuesFromMultipleArguments(int numValuesNeeded,
76                                        TFunction &function,
77                                        TIntermSequence *returnCtorArgs)
78 {
79     size_t numParameters = function.getParamCount();
80     size_t paramIndex    = 0;
81     uint8_t colIndex     = 0;
82     uint8_t rowIndex     = 0;
83 
84     for (int i = 0; i < numValuesNeeded && paramIndex < numParameters; ++i)
85     {
86         const TVariable *p       = function.getParam(paramIndex);
87         TIntermSymbol *parameter = new TIntermSymbol(p);
88         if (parameter->isScalar())
89         {
90             returnCtorArgs->push_back(parameter);
91             ++paramIndex;
92         }
93         else if (parameter->isVector())
94         {
95             TIntermSwizzle *vecS = new TIntermSwizzle(parameter->deepCopy(), {rowIndex++});
96             returnCtorArgs->push_back(vecS);
97             if (rowIndex == parameter->getNominalSize())
98             {
99                 ++paramIndex;
100                 rowIndex = 0;
101             }
102         }
103         else if (parameter->isMatrix())
104         {
105             AppendMatrixElementArgument(parameter, colIndex, rowIndex++, returnCtorArgs);
106             if (rowIndex == parameter->getSecondarySize())
107             {
108                 rowIndex = 0;
109                 ++colIndex;
110                 if (colIndex == parameter->getNominalSize())
111                 {
112                     colIndex = 0;
113                     ++paramIndex;
114                 }
115             }
116         }
117     }
118 }
119 
120 // Adds the arguments for a matrix constructor from a scalar
121 // putting the scalar along the diagonal and 0 everywhere else.
AppendMatrixFromScalarArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)122 void AppendMatrixFromScalarArgument(const TType &type,
123                                     TFunction &function,
124                                     TIntermSequence *returnCtorArgs)
125 {
126     const TVariable *var  = function.getParam(0);
127     TIntermSymbol *v      = new TIntermSymbol(var);
128     const uint8_t numCols = type.getNominalSize();
129     const uint8_t numRows = type.getSecondarySize();
130     for (uint8_t col = 0; col < numCols; ++col)
131     {
132         for (uint8_t row = 0; row < numRows; ++row)
133         {
134             if (col == row)
135             {
136                 returnCtorArgs->push_back(v->deepCopy());
137             }
138             else
139             {
140                 returnCtorArgs->push_back(CreateFloatNode(0.0f, sh::EbpUndefined));
141             }
142         }
143     }
144 }
145 
146 // Add the argument for a matrix constructor from a matrix
147 // copying elements from the same column/row and otherwise
148 // initialize to the identity matrix.
AppendMatrixFromMatrixArgument(const TType & type,TFunction & function,TIntermSequence * returnCtorArgs)149 void AppendMatrixFromMatrixArgument(const TType &type,
150                                     TFunction &function,
151                                     TIntermSequence *returnCtorArgs)
152 {
153     const TVariable *var  = function.getParam(0);
154     TIntermSymbol *v      = new TIntermSymbol(var);
155     const uint8_t dstCols = type.getNominalSize();
156     const uint8_t dstRows = type.getSecondarySize();
157     const uint8_t srcCols = v->getNominalSize();
158     const uint8_t srcRows = v->getSecondarySize();
159     for (uint8_t dstCol = 0; dstCol < dstCols; ++dstCol)
160     {
161         for (uint8_t dstRow = 0; dstRow < dstRows; ++dstRow)
162         {
163             if (dstRow < srcRows && dstCol < srcCols)
164             {
165                 AppendMatrixElementArgument(v, dstCol, dstRow, returnCtorArgs);
166             }
167             else
168             {
169                 returnCtorArgs->push_back(
170                     CreateFloatNode(dstRow == dstCol ? 1.0f : 0.0f, sh::EbpUndefined));
171             }
172         }
173     }
174 }
175 
176 class Rebuild : public TIntermRebuild
177 {
178   public:
Rebuild(TCompiler & compiler)179     explicit Rebuild(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {}
visitAggregatePost(TIntermAggregate & node)180     PostResult visitAggregatePost(TIntermAggregate &node) override
181     {
182         if (!node.isConstructor())
183         {
184             return node;
185         }
186 
187         TIntermSequence &arguments = *node.getSequence();
188         if (arguments.empty())
189         {
190             return node;
191         }
192 
193         const TType &type     = node.getType();
194         const TType &arg0Type = arguments[0]->getAsTyped()->getType();
195 
196         if (!type.isScalar() && !type.isVector() && !type.isMatrix())
197         {
198             return node;
199         }
200 
201         if (type.isArray())
202         {
203             return node;
204         }
205 
206         // check for type_ctor(sameType)
207         // scalar(scalar) -> passthrough
208         // vecN(vecN) -> passthrough
209         // matN(matN) -> passthrough
210         if (arguments.size() == 1 && arg0Type == type)
211         {
212             return node;
213         }
214 
215         // The following are simple casts:
216         //
217         // - basic(s) (where basic is int, uint, float or bool, and s is scalar).
218         // - gvecN(vN) (where the argument is a single vector with the same number of components).
219         // - matNxM(mNxM) (where the argument is a single matrix with the same dimensions).  Note
220         // that
221         //   matrices are always float, so there's no actual cast and this would be a no-op.
222         //
223         const bool isSingleScalarCast =
224             arguments.size() == 1 && type.isScalar() && arg0Type.isScalar();
225         const bool isSingleVectorCast = arguments.size() == 1 && type.isVector() &&
226                                         arg0Type.isVector() &&
227                                         type.getNominalSize() == arg0Type.getNominalSize();
228         const bool isSingleMatrixCast =
229             arguments.size() == 1 && type.isMatrix() && arg0Type.isMatrix() &&
230             type.getCols() == arg0Type.getCols() && type.getRows() == arg0Type.getRows();
231         if (isSingleScalarCast || isSingleVectorCast || isSingleMatrixCast)
232         {
233             return node;
234         }
235 
236         // Cases we need to handle:
237         // scalar(vec)
238         // scalar(mat)
239         // vecN(scalar)
240         // vecN(vecM)
241         // vecN(a,...)
242         // matN(scalar) -> diag
243         // matN(vec) -> fail!
244         // manN(matM) -> corner + ident
245         // matN(a, ...)
246 
247         // Build a function and pass all the constructor's arguments to it.
248         TIntermBlock *body  = new TIntermBlock;
249         TFunction *function = new TFunction(&mSymbolTable, ImmutableString(""),
250                                             SymbolType::AngleInternal, &type, true);
251 
252         for (size_t i = 0; i < arguments.size(); ++i)
253         {
254             TIntermTyped &arg = *arguments[i]->getAsTyped();
255             TType *argType    = new TType(arg.getBasicType(), arg.getPrecision(), EvqParamIn,
256                                           arg.getNominalSize(), arg.getSecondarySize());
257             TVariable *var    = CreateTempVariable(&mSymbolTable, argType);
258             function->addParameter(var);
259         }
260 
261         // Build a return statement for the function that
262         // converts the arguments into the required type.
263         TIntermSequence *returnCtorArgs = new TIntermSequence();
264 
265         if (type.isScalar())
266         {
267             AppendScalarFromNonScalarArguments(*function, returnCtorArgs);
268         }
269         else if (type.isVector())
270         {
271             if (arguments.size() == 1 && arg0Type.isScalar())
272             {
273                 AppendVectorFromScalarArgument(type, *function, returnCtorArgs);
274             }
275             else
276             {
277                 AppendValuesFromMultipleArguments(type.getNominalSize(), *function, returnCtorArgs);
278             }
279         }
280         else if (type.isMatrix())
281         {
282             if (arguments.size() == 1 && arg0Type.isScalar())
283             {
284                 // MSL already handles this case
285                 AppendMatrixFromScalarArgument(type, *function, returnCtorArgs);
286             }
287             else if (arg0Type.isMatrix())
288             {
289                 AppendMatrixFromMatrixArgument(type, *function, returnCtorArgs);
290             }
291             else
292             {
293                 AppendValuesFromMultipleArguments(type.getNominalSize() * type.getSecondarySize(),
294                                                   *function, returnCtorArgs);
295             }
296         }
297 
298         TIntermBranch *returnStatement =
299             new TIntermBranch(EOpReturn, TIntermAggregate::CreateConstructor(type, returnCtorArgs));
300         body->appendStatement(returnStatement);
301 
302         TIntermFunctionDefinition *functionDefinition =
303             CreateInternalFunctionDefinitionNode(*function, body);
304         mFunctionDefs.push_back(functionDefinition);
305 
306         TIntermTyped *functionCall = TIntermAggregate::CreateFunctionCall(*function, &arguments);
307 
308         return *functionCall;
309     }
310 
rewrite(TIntermBlock & root)311     bool rewrite(TIntermBlock &root)
312     {
313         if (!rebuildInPlace(root))
314         {
315             return true;
316         }
317 
318         size_t firstFunctionIndex = FindFirstFunctionDefinitionIndex(&root);
319         for (TIntermFunctionDefinition *functionDefinition : mFunctionDefs)
320         {
321             root.insertChildNodes(firstFunctionIndex, TIntermSequence({functionDefinition}));
322         }
323 
324         return mCompiler.validateAST(&root);
325     }
326 
327   private:
328     TVector<TIntermFunctionDefinition *> mFunctionDefs;
329 };
330 
331 }  // anonymous namespace
332 
ConvertUnsupportedConstructorsToFunctionCalls(TCompiler & compiler,TIntermBlock & root)333 bool sh::ConvertUnsupportedConstructorsToFunctionCalls(TCompiler &compiler, TIntermBlock &root)
334 {
335     return Rebuild(compiler).rewrite(root);
336 }
337