xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/EmitMetal.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <cctype>
8 #include <map>
9 
10 #include "common/system_utils.h"
11 #include "compiler/translator/BaseTypes.h"
12 #include "compiler/translator/ImmutableStringBuilder.h"
13 #include "compiler/translator/Name.h"
14 #include "compiler/translator/OutputTree.h"
15 #include "compiler/translator/SymbolTable.h"
16 #include "compiler/translator/msl/AstHelpers.h"
17 #include "compiler/translator/msl/DebugSink.h"
18 #include "compiler/translator/msl/EmitMetal.h"
19 #include "compiler/translator/msl/Layout.h"
20 #include "compiler/translator/msl/ProgramPrelude.h"
21 #include "compiler/translator/msl/RewritePipelines.h"
22 #include "compiler/translator/msl/TranslatorMSL.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 
25 using namespace sh;
26 
27 ////////////////////////////////////////////////////////////////////////////////
28 
// Output sink type used by the emitter: DebugSink when ANGLE_ENABLE_ASSERTS is defined,
// otherwise a plain TInfoSinkBase.
#if defined(ANGLE_ENABLE_ASSERTS)
using Sink = DebugSink;
#else
using Sink = TInfoSinkBase;
#endif
34 
35 ////////////////////////////////////////////////////////////////////////////////
36 
37 namespace
38 {
39 
40 struct VarDecl
41 {
VarDecl__anon45ec70e50111::VarDecl42     explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anon45ec70e50111::VarDecl43     explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
44 
variable__anon45ec70e50111::VarDecl45     ANGLE_INLINE const TVariable &variable() const
46     {
47         ASSERT(isVariable());
48         return *mVariable;
49     }
50 
field__anon45ec70e50111::VarDecl51     ANGLE_INLINE const TField &field() const
52     {
53         ASSERT(isField());
54         return *mField;
55     }
56 
isVariable__anon45ec70e50111::VarDecl57     ANGLE_INLINE bool isVariable() const { return !mIsField; }
58 
isField__anon45ec70e50111::VarDecl59     ANGLE_INLINE bool isField() const { return mIsField; }
60 
type__anon45ec70e50111::VarDecl61     const TType &type() const { return isField() ? *field().type() : variable().getType(); }
62 
symbolType__anon45ec70e50111::VarDecl63     SymbolType symbolType() const
64     {
65         return isField() ? field().symbolType() : variable().symbolType();
66     }
67 
68   private:
69     union
70     {
71         const TVariable *mVariable;
72         const TField *mField;
73     };
74     bool mIsField;
75 };
76 
// AST traverser that writes the Metal Shading Language translation of a shader into a Sink.
// NOTE(review): the base is constructed with TIntermTraverser(true, false, false) in the
// constructor; presumably that enables pre-visit only — confirm in IntermTraverse.h.
class GenMetalTraverser : public TIntermTraverser
{
  public:
    // Asserts that indentation, switch state and pointer parentheses are balanced.
    ~GenMetalTraverser() override;

    GenMetalTraverser(const TCompiler &compiler,
                      Sink &out,
                      IdGen &idGen,
                      const PipelineStructs &pipelineStructs,
                      SymbolEnv &symbolEnv,
                      const ShCompileOptions &compileOptions);

    void visitSymbol(TIntermSymbol *) override;
    void visitConstantUnion(TIntermConstantUnion *) override;
    bool visitSwizzle(Visit, TIntermSwizzle *) override;
    bool visitBinary(Visit, TIntermBinary *) override;
    bool visitUnary(Visit, TIntermUnary *) override;
    bool visitTernary(Visit, TIntermTernary *) override;
    bool visitIfElse(Visit, TIntermIfElse *) override;
    bool visitSwitch(Visit, TIntermSwitch *) override;
    bool visitCase(Visit, TIntermCase *) override;
    void visitFunctionPrototype(TIntermFunctionPrototype *) override;
    bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
    bool visitAggregate(Visit, TIntermAggregate *) override;
    bool visitBlock(Visit, TIntermBlock *) override;
    bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
    bool visitDeclaration(Visit, TIntermDeclaration *) override;
    bool visitLoop(Visit, TIntermLoop *) override;
    // Helpers for visitLoop, one per loop flavor (not TIntermTraverser overrides).
    bool visitForLoop(TIntermLoop *);
    bool visitWhileLoop(TIntermLoop *);
    bool visitDoWhileLoop(TIntermLoop *);
    bool visitBranch(Visit, TIntermBranch *) override;

  private:
    // Lookup table from a function name to the Name to emit; built once per traverser
    // by BuildFuncToName() (see mFuncToName).
    using FuncToName = std::map<ImmutableString, Name>;
    static FuncToName BuildFuncToName();

    // Options controlling how a single variable declaration is emitted.
    struct EmitVariableDeclarationConfig
    {
        bool isParameter = false;
        // emitPostQualifier only emits [[vertex_id]] / [[point_coord]] when this is set.
        bool isMainParameter   = false;
        bool emitPostQualifier = false;
        bool isPacked          = false;
        bool disableStructSpecifier = false;
        bool isUBO                  = false;
        // NOTE(review): non-null presumably means "declare as a pointer / reference in this
        // address space"; nullptr means neither — confirm at the use sites.
        const AddressSpace *isPointer   = nullptr;
        const AddressSpace *isReference = nullptr;
    };

    struct EmitTypeConfig
    {
        const EmitVariableDeclarationConfig *evdConfig = nullptr;
    };

    // Writes two spaces per mIndentLevel, unless the current line is already indented.
    void emitIndentation();
    // Writes "(*" and records it in mOpenPointerParenCount.
    void emitOpeningPointerParen();
    // Writes ")" only if an opening pointer paren is still outstanding.
    void emitClosingPointerParen();
    void emitFunctionSignature(const TFunction &func);
    void emitFunctionReturn(const TFunction &func);
    void emitFunctionParameter(const TFunction &func, const TVariable &param);

    void emitNameOf(const TField &object);
    void emitNameOf(const TSymbol &object);
    void emitNameOf(const VarDecl &object);

    void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
    void emitType(const TType &type, const EmitTypeConfig &etConfig);
    // Emits trailing Metal attributes such as " [[position]]" selected by `qualifier`.
    void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
                           const VarDecl &decl,
                           const TQualifier qualifier);

    void emitLoopBody(TIntermBlock *bodyNode);

    // NOTE(review): presumably running counters for the next [[attribute(n)]] /
    // [[color(n)]] index assigned while emitting a struct's fields — confirm.
    struct FieldAnnotationIndices
    {
        size_t attribute = 0;
        size_t color     = 0;
    };

    void emitFieldDeclaration(const TField &field,
                              const TStructure &parent,
                              FieldAnnotationIndices &annotationIndices);
    void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
    void emitUniformBufferDeclaration(const TField &field,
                                      FieldAnnotationIndices &annotationIndices);
    void emitStructDeclaration(const TType &type);
    void emitOrdinaryVariableDeclaration(const VarDecl &decl,
                                         const EmitVariableDeclarationConfig &evdConfig);
    void emitVariableDeclaration(const VarDecl &decl,
                                 const EmitVariableDeclarationConfig &evdConfig);

    void emitOpenBrace();
    void emitCloseBrace();

    // Traverses `node`, wrapping it in parentheses when Parenthesize(node) says so.
    void groupedTraverse(TIntermNode &node);

    const TField &getDirectField(const TFieldListCollection &fieldsNode,
                                 const TConstantUnion &index);
    const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);

    const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
                                                 const size_t size);

    const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);

    void emitSingleConstant(const TConstantUnion *const constUnion);

  private:
    // NOTE: declaration order below is also initialization order; the constructor's
    // initializer list follows it, so keep the two in sync when adding members.
    Sink &mOut;
    const TCompiler &mCompiler;
    const PipelineStructs &mPipelineStructs;
    SymbolEnv &mSymbolEnv;
    IdGen &mIdGen;
    // -1 when outside any braces; emitIndentation requires >= 0 and the destructor
    // requires it to be back at -1.
    int mIndentLevel = -1;
    // Sink size right after the last emitIndentation() wrote indentation; used to avoid
    // indenting the same line twice.
    int mLastIndentationPos = -1;
    // Number of "(*" written by emitOpeningPointerParen that are still unmatched.
    int mOpenPointerParenCount        = 0;
    bool mParentIsSwitch              = false;
    bool isTraversingVertexMain       = false;
    bool mTemporarilyDisableSemicolon = false;
    std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
    const FuncToName mFuncToName           = BuildFuncToName();
    size_t mMainTextureIndex               = 0;
    size_t mMainSamplerIndex               = 0;
    // Initialized from compileOptions.metal.defaultUniformsBindingIndex.
    size_t mMainUniformBufferIndex = 0;
    // Initialized from compileOptions.metal.driverUniformsBindingIndex.
    size_t mDriverUniformsBindingIndex = 0;
    // Initialized from compileOptions.metal.UBOArgumentBufferBindingIndex.
    size_t mUBOArgumentBufferBindingIndex = 0;
    // True when compileOptions.pls.fragmentSyncType is RasterOrderGroups_Metal.
    bool mRasterOrderGroupsSupported = false;
    // Copied from compileOptions.metal.injectAsmStatementIntoLoopBodies.
    bool mInjectAsmStatementIntoLoopBodies = false;
};
206 }  // anonymous namespace
207 
// Sanity checks at the end of traversal: indentation must be back at its initial value
// (-1), no switch may still be recorded as the parent, and every "(*" written by
// emitOpeningPointerParen must have been closed.
GenMetalTraverser::~GenMetalTraverser()
{
    ASSERT(mIndentLevel == -1);
    ASSERT(!mParentIsSwitch);
    ASSERT(mOpenPointerParenCount == 0);
}
214 
// Binds the traverser to its output sink and shared translation state, and copies the
// Metal-specific binding indices and feature switches out of compileOptions.
// The initializer list follows the member declaration order in the class.
GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
                                     Sink &out,
                                     IdGen &idGen,
                                     const PipelineStructs &pipelineStructs,
                                     SymbolEnv &symbolEnv,
                                     const ShCompileOptions &compileOptions)
    // NOTE(review): presumably (preVisit, inVisit, postVisit) — confirm in IntermTraverse.h.
    : TIntermTraverser(true, false, false),
      mOut(out),
      mCompiler(compiler),
      mPipelineStructs(pipelineStructs),
      mSymbolEnv(symbolEnv),
      mIdGen(idGen),
      mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
      mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
      mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
      // Raster order groups are only used when the requested fragment synchronization is
      // the Metal ROG flavor.
      mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
                                  ShFragmentSynchronizationType::RasterOrderGroups_Metal),
      mInjectAsmStatementIntoLoopBodies(compileOptions.metal.injectAsmStatementIntoLoopBodies)
{}
234 
emitIndentation()235 void GenMetalTraverser::emitIndentation()
236 {
237     ASSERT(mIndentLevel >= 0);
238 
239     if (mLastIndentationPos == mOut.size())
240     {
241         return;  // Line is already indented.
242     }
243 
244     for (int i = 0; i < mIndentLevel; ++i)
245     {
246         mOut << "  ";
247     }
248 
249     mLastIndentationPos = mOut.size();
250 }
251 
emitOpeningPointerParen()252 void GenMetalTraverser::emitOpeningPointerParen()
253 {
254     mOut << "(*";
255     mOpenPointerParenCount++;
256 }
257 
emitClosingPointerParen()258 void GenMetalTraverser::emitClosingPointerParen()
259 {
260     if (mOpenPointerParenCount > 0)
261     {
262         mOut << ")";
263         mOpenPointerParenCount--;
264     }
265 }
266 
// Maps a TOperator to the Metal Shading Language spelling used when emitting it.
//
// The returned string is one of:
//   - an operator token ("+", "*=", "!=/*xor*/", ...);
//   - a Metal standard-library function ("metal::abs", "metal::powr", ...);
//   - an "as_type<...>" bit-cast for the float/int bit-reinterpretation built-ins;
//   - an ANGLE_-prefixed helper name. NOTE(review): these are presumably defined by the
//     program prelude (ProgramPrelude.h is included by this file) — confirm there;
//   - "TOperator_TODO" for operators not implemented yet (after UNIMPLEMENTED());
//   - "" for unary plus applied to a matrix;
//   - nullptr for EOpNull/EOpConstruct/calls/indexing (after UNREACHABLE()) and for any
//     operator not listed below.
//
// resultType is the node's type; argType0..argType2 are operand types. Which pointers are
// dereferenced depends on `op`: EOpEqual/EOpNotEqual read argType0 and argType1, EOpMix
// reads argType1 and (if non-null) argType2, EOpAtan only null-tests argType1, and several
// cases read argType0. Callers must pass non-null pointers for operands the chosen case uses.
static const char *GetOperatorString(TOperator op,
                                     const TType &resultType,
                                     const TType *argType0,
                                     const TType *argType1,
                                     const TType *argType2)
{
    switch (op)
    {
        case TOperator::EOpComma:
            return ",";
        case TOperator::EOpAssign:
            return "=";
        case TOperator::EOpInitialize:
            return "=";
        case TOperator::EOpAddAssign:
            return "+=";
        case TOperator::EOpSubAssign:
            return "-=";
        case TOperator::EOpMulAssign:
            return "*=";
        case TOperator::EOpDivAssign:
            return "/=";
        case TOperator::EOpIModAssign:
            return "%=";
        case TOperator::EOpBitShiftLeftAssign:
            return "<<=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRightAssign:
            return ">>=";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAndAssign:
            return "&=";
        case TOperator::EOpBitwiseXorAssign:
            return "^=";
        case TOperator::EOpBitwiseOrAssign:
            return "|=";
        case TOperator::EOpAdd:
            return "+";
        case TOperator::EOpSub:
            return "-";
        case TOperator::EOpMul:
            return "*";
        case TOperator::EOpDiv:
            return "/";
        case TOperator::EOpIMod:
            return "%";
        case TOperator::EOpBitShiftLeft:
            return "<<";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitShiftRight:
            return ">>";  // TODO: Check logical vs arithmetic shifting.
        case TOperator::EOpBitwiseAnd:
            return "&";
        case TOperator::EOpBitwiseXor:
            return "^";
        case TOperator::EOpBitwiseOr:
            return "|";
        case TOperator::EOpLessThan:
            return "<";
        case TOperator::EOpGreaterThan:
            return ">";
        case TOperator::EOpLessThanEqual:
            return "<=";
        case TOperator::EOpGreaterThanEqual:
            return ">=";
        case TOperator::EOpLessThanComponentWise:
            return "<";
        case TOperator::EOpLessThanEqualComponentWise:
            return "<=";
        case TOperator::EOpGreaterThanEqualComponentWise:
            return ">=";
        case TOperator::EOpGreaterThanComponentWise:
            return ">";
        case TOperator::EOpLogicalOr:
            return "||";
        case TOperator::EOpLogicalXor:
            return "!=/*xor*/";  // XXX: This might need to be handled differently for some obtuse
                                 // use case.
        case TOperator::EOpLogicalAnd:
            return "&&";
        case TOperator::EOpNegative:
            return "-";
        case TOperator::EOpPositive:
            // Unary plus on a matrix is emitted as nothing (the operand alone).
            // NOTE(review): presumably because Metal has no unary + for matrices — confirm.
            if (argType0->isMatrix())
            {
                return "";
            }
            return "+";
        case TOperator::EOpLogicalNot:
            return "!";
        case TOperator::EOpNotComponentWise:
            return "!";
        case TOperator::EOpBitwiseNot:
            return "~";
        case TOperator::EOpPostIncrement:
            return "++";
        case TOperator::EOpPostDecrement:
            return "--";
        case TOperator::EOpPreIncrement:
            return "++";
        case TOperator::EOpPreDecrement:
            return "--";
        case TOperator::EOpVectorTimesScalarAssign:
            return "*=";
        case TOperator::EOpVectorTimesMatrixAssign:
            return "*=";
        case TOperator::EOpMatrixTimesScalarAssign:
            return "*=";
        case TOperator::EOpMatrixTimesMatrixAssign:
            return "*=";
        case TOperator::EOpVectorTimesScalar:
            return "*";
        case TOperator::EOpVectorTimesMatrix:
            return "*";
        case TOperator::EOpMatrixTimesVector:
            return "*";
        case TOperator::EOpMatrixTimesScalar:
            return "*";
        case TOperator::EOpMatrixTimesMatrix:
            return "*";
        case TOperator::EOpEqualComponentWise:
            return "==";
        case TOperator::EOpNotEqualComponentWise:
            return "!=";

        // Whole-value (in)equality. Arrays of structs, and aggregate operands (vectors,
        // structs, arrays, matrices), use ANGLE helpers that reduce to a single bool;
        // scalars use the plain operator.
        case TOperator::EOpEqual:
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_equalStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->getStruct() && argType1->getStruct()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))

            {
                return "ANGLE_equal";
            }

            return "==";

        case TOperator::EOpNotEqual:
            if ((argType0->getStruct() && argType1->getStruct()) &&
                (argType0->isArray() && argType1->isArray()))
            {
                return "ANGLE_notEqualStructArray";
            }

            if ((argType0->isVector() && argType1->isVector()) ||
                (argType0->isArray() && argType1->isArray()) ||
                (argType0->isMatrix() && argType1->isMatrix()))
            {
                return "ANGLE_notEqual";
            }
            else if (argType0->getStruct() && argType1->getStruct())
            {
                return "ANGLE_notEqualStruct";
            }
            return "!=";

        case TOperator::EOpKill:
            UNIMPLEMENTED();
            return "kill";
        case TOperator::EOpReturn:
            return "return";
        case TOperator::EOpBreak:
            return "break";
        case TOperator::EOpContinue:
            return "continue";

        case TOperator::EOpRadians:
            return "ANGLE_radians";
        case TOperator::EOpDegrees:
            return "ANGLE_degrees";
        case TOperator::EOpAtan:
            // One-argument atan(y_over_x) vs two-argument atan(y, x).
            return argType1 == nullptr ? "metal::atan" : "metal::atan2";
        case TOperator::EOpMod:
            return "ANGLE_mod";  // differs from metal::mod
        // For scalar operands the vector geometry functions are replaced by helpers or by
        // their scalar equivalents (|x| for length, x*y for dot, sign(x) for normalize).
        case TOperator::EOpRefract:
            return argType0->isVector() ? "metal::refract" : "ANGLE_refract_scalar";
        case TOperator::EOpDistance:
            return argType0->isVector() ? "metal::distance" : "ANGLE_distance_scalar";
        case TOperator::EOpLength:
            return argType0->isVector() ? "metal::length" : "metal::abs";
        case TOperator::EOpDot:
            return argType0->isVector() ? "metal::dot" : "*";
        case TOperator::EOpNormalize:
            return argType0->isVector() ? "metal::fast::normalize" : "metal::sign";
        case TOperator::EOpFaceforward:
            return argType0->isVector() ? "metal::faceforward" : "ANGLE_faceforward_scalar";
        case TOperator::EOpReflect:
            return argType0->isVector() ? "metal::reflect" : "ANGLE_reflect_scalar";
        case TOperator::EOpMatrixCompMult:
            return "ANGLE_componentWiseMultiply";
        case TOperator::EOpOuterProduct:
            return "ANGLE_outerProduct";
        case TOperator::EOpSign:
            // Float operands use metal::sign; other basic types use an ANGLE helper.
            return argType0->getBasicType() == EbtFloat ? "metal::sign" : "ANGLE_sign_int";

        case TOperator::EOpAbs:
            return "metal::abs";
        case TOperator::EOpAll:
            return "metal::all";
        case TOperator::EOpAny:
            return "metal::any";
        case TOperator::EOpSin:
            return "metal::sin";
        case TOperator::EOpCos:
            return "metal::cos";
        case TOperator::EOpTan:
            return "metal::tan";
        case TOperator::EOpAsin:
            return "metal::asin";
        case TOperator::EOpAcos:
            return "metal::acos";
        case TOperator::EOpSinh:
            return "metal::sinh";
        case TOperator::EOpCosh:
            return "metal::cosh";
        case TOperator::EOpTanh:
            // highp results use the precise:: variant.
            return resultType.getPrecision() == TPrecision::EbpHigh ? "metal::precise::tanh"
                                                                    : "metal::tanh";
        case TOperator::EOpAsinh:
            return "metal::asinh";
        case TOperator::EOpAcosh:
            return "metal::acosh";
        case TOperator::EOpAtanh:
            return "metal::atanh";
        case TOperator::EOpFma:
            return "metal::fma";
        case TOperator::EOpPow:
            return "metal::powr";  // GLSL's pow excludes negative x
        case TOperator::EOpExp:
            return "metal::exp";
        case TOperator::EOpExp2:
            return "metal::exp2";
        case TOperator::EOpLog:
            return "metal::log";
        case TOperator::EOpLog2:
            return "metal::log2";
        case TOperator::EOpSqrt:
            return "metal::sqrt";
        case TOperator::EOpFloor:
            return "metal::floor";
        case TOperator::EOpTrunc:
            return "metal::trunc";
        case TOperator::EOpCeil:
            return "metal::ceil";
        case TOperator::EOpFract:
            return "metal::fract";
        case TOperator::EOpMin:
            return "metal::min";
        case TOperator::EOpMax:
            return "metal::max";
        case TOperator::EOpRound:
            return "metal::round";
        case TOperator::EOpRoundEven:
            return "metal::rint";
        case TOperator::EOpClamp:
            return "metal::clamp";  // TODO fast vs precise namespace
        case TOperator::EOpSaturate:
            return "metal::saturate";  // TODO fast vs precise namespace
        case TOperator::EOpMix:
            // Non-scalar mix with a bool selector goes through an ANGLE helper; everything
            // else maps to metal::mix.
            if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
            {
                return "ANGLE_mix_bool";
            }
            return "metal::mix";
        case TOperator::EOpStep:
            return "metal::step";
        case TOperator::EOpSmoothstep:
            return "metal::smoothstep";
        case TOperator::EOpModf:
            return "metal::modf";
        case TOperator::EOpIsnan:
            return "metal::isnan";
        case TOperator::EOpIsinf:
            return "metal::isinf";
        case TOperator::EOpLdexp:
            return "metal::ldexp";
        case TOperator::EOpFrexp:
            return "metal::frexp";
        case TOperator::EOpInversesqrt:
            return "metal::rsqrt";
        case TOperator::EOpCross:
            return "metal::cross";
        case TOperator::EOpDFdx:
            return "metal::dfdx";
        case TOperator::EOpDFdy:
            return "metal::dfdy";
        case TOperator::EOpFwidth:
            return "metal::fwidth";
        case TOperator::EOpTranspose:
            return "metal::transpose";
        case TOperator::EOpDeterminant:
            return "metal::determinant";

        case TOperator::EOpInverse:
            return "ANGLE_inverse";

        case TOperator::EOpInterpolateAtCentroid:
            return "ANGLE_interpolateAtCentroid";
        case TOperator::EOpInterpolateAtSample:
            return "ANGLE_interpolateAtSample";
        case TOperator::EOpInterpolateAtOffset:
            return "ANGLE_interpolateAtOffset";
        case TOperator::EOpInterpolateAtCenter:
            return "ANGLE_interpolateAtCenter";

        // Bit-reinterpreting conversions become as_type<T>, where T is the result's basic
        // type (and vector width, for vectors). Every path below returns.
        case TOperator::EOpFloatBitsToInt:
        case TOperator::EOpFloatBitsToUint:
        case TOperator::EOpIntBitsToFloat:
        case TOperator::EOpUintBitsToFloat:
        {
#define RETURN_AS_TYPE_SCALAR()             \
    do                                      \
        switch (resultType.getBasicType())  \
        {                                   \
            case TBasicType::EbtInt:        \
                return "as_type<int>";      \
            case TBasicType::EbtUInt:       \
                return "as_type<uint32_t>"; \
            case TBasicType::EbtFloat:      \
                return "as_type<float>";    \
            default:                        \
                UNIMPLEMENTED();            \
                return "TOperator_TODO";    \
        }                                   \
    while (false)

#define RETURN_AS_TYPE(post)                     \
    do                                           \
        switch (resultType.getBasicType())       \
        {                                        \
            case TBasicType::EbtInt:             \
                return "as_type<int" post ">";   \
            case TBasicType::EbtUInt:            \
                return "as_type<uint" post ">";  \
            case TBasicType::EbtFloat:           \
                return "as_type<float" post ">"; \
            default:                             \
                UNIMPLEMENTED();                 \
                return "TOperator_TODO";         \
        }                                        \
    while (false)

            if (resultType.isScalar())
            {
                RETURN_AS_TYPE_SCALAR();
            }
            else if (resultType.isVector())
            {
                switch (resultType.getNominalSize())
                {
                    case 2:
                        RETURN_AS_TYPE("2");
                    case 3:
                        RETURN_AS_TYPE("3");
                    case 4:
                        RETURN_AS_TYPE("4");
                    default:
                        UNREACHABLE();
                        return nullptr;
                }
            }
            else
            {
                UNIMPLEMENTED();
                return "TOperator_TODO";
            }

#undef RETURN_AS_TYPE
#undef RETURN_AS_TYPE_SCALAR
        }

        case TOperator::EOpPackUnorm2x16:
            return "metal::pack_float_to_unorm2x16";
        case TOperator::EOpPackSnorm2x16:
            return "metal::pack_float_to_snorm2x16";

        case TOperator::EOpPackUnorm4x8:
            return "metal::pack_float_to_unorm4x8";
        case TOperator::EOpPackSnorm4x8:
            return "metal::pack_float_to_snorm4x8";

        case TOperator::EOpUnpackUnorm2x16:
            return "metal::unpack_unorm2x16_to_float";
        case TOperator::EOpUnpackSnorm2x16:
            return "metal::unpack_snorm2x16_to_float";

        case TOperator::EOpUnpackUnorm4x8:
            return "metal::unpack_unorm4x8_to_float";
        case TOperator::EOpUnpackSnorm4x8:
            return "metal::unpack_snorm4x8_to_float";

        case TOperator::EOpPackHalf2x16:
            return "ANGLE_pack_half_2x16";
        case TOperator::EOpUnpackHalf2x16:
            return "ANGLE_unpack_half_2x16";

        case TOperator::EOpNumSamples:
            return "metal::get_num_samples";
        case TOperator::EOpSamplePosition:
            return "metal::get_sample_position";

        // Not supported by this backend yet.
        case TOperator::EOpBitfieldExtract:
        case TOperator::EOpBitfieldInsert:
        case TOperator::EOpBitfieldReverse:
        case TOperator::EOpBitCount:
        case TOperator::EOpFindLSB:
        case TOperator::EOpFindMSB:
        case TOperator::EOpUaddCarry:
        case TOperator::EOpUsubBorrow:
        case TOperator::EOpUmulExtended:
        case TOperator::EOpImulExtended:
        case TOperator::EOpBarrier:
        case TOperator::EOpMemoryBarrier:
        case TOperator::EOpMemoryBarrierAtomicCounter:
        case TOperator::EOpMemoryBarrierBuffer:
        case TOperator::EOpMemoryBarrierShared:
        case TOperator::EOpGroupMemoryBarrier:
        case TOperator::EOpAtomicAdd:
        case TOperator::EOpAtomicMin:
        case TOperator::EOpAtomicMax:
        case TOperator::EOpAtomicAnd:
        case TOperator::EOpAtomicOr:
        case TOperator::EOpAtomicXor:
        case TOperator::EOpAtomicExchange:
        case TOperator::EOpAtomicCompSwap:
        case TOperator::EOpEmitVertex:
        case TOperator::EOpEndPrimitive:
        case TOperator::EOpArrayLength:
            UNIMPLEMENTED();
            return "TOperator_TODO";

        // Constructors, calls and indexing are not expected to be looked up here.
        case TOperator::EOpNull:
        case TOperator::EOpConstruct:
        case TOperator::EOpCallFunctionInAST:
        case TOperator::EOpCallInternalRawFunction:
        case TOperator::EOpIndexDirect:
        case TOperator::EOpIndexIndirect:
        case TOperator::EOpIndexDirectStruct:
        case TOperator::EOpIndexDirectInterfaceBlock:
            UNREACHABLE();
            return nullptr;
        default:
            // Any other built-in function.
            return nullptr;
    }
}
716 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)717 static bool IsSymbolicOperator(TOperator op,
718                                const TType &resultType,
719                                const TType *argType0,
720                                const TType *argType1)
721 {
722     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
723     if (operatorString == nullptr)
724     {
725         return false;
726     }
727     return !std::isalnum(operatorString[0]);
728 }
729 
AsSpecificBinaryNode(TIntermNode & node,TOperator op)730 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
731 {
732     TIntermBinary *binaryNode = node.getAsBinaryNode();
733     if (binaryNode)
734     {
735         return binaryNode->getOp() == op ? binaryNode : nullptr;
736     }
737     return nullptr;
738 }
739 
Parenthesize(TIntermNode & node)740 static bool Parenthesize(TIntermNode &node)
741 {
742     if (node.getAsSymbolNode())
743     {
744         return false;
745     }
746     if (node.getAsConstantUnion())
747     {
748         return false;
749     }
750     if (node.getAsAggregate())
751     {
752         return false;
753     }
754     if (node.getAsSwizzleNode())
755     {
756         return false;
757     }
758 
759     if (TIntermUnary *unaryNode = node.getAsUnaryNode())
760     {
761         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
762         const TType &resultType = unaryNode->getType();
763         const TType &argType    = unaryNode->getOperand()->getType();
764         return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
765     }
766 
767     if (TIntermBinary *binaryNode = node.getAsBinaryNode())
768     {
769         // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
770         const TOperator op = binaryNode->getOp();
771         switch (op)
772         {
773             case TOperator::EOpIndexDirectStruct:
774             case TOperator::EOpIndexDirectInterfaceBlock:
775             case TOperator::EOpIndexDirect:
776             case TOperator::EOpIndexIndirect:
777                 return Parenthesize(*binaryNode->getLeft());
778 
779             case TOperator::EOpAssign:
780             case TOperator::EOpInitialize:
781                 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
782 
783             default:
784             {
785                 const TType &resultType = binaryNode->getType();
786                 const TType &leftType   = binaryNode->getLeft()->getType();
787                 const TType &rightType  = binaryNode->getRight()->getType();
788                 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
789             }
790         }
791     }
792 
793     return true;
794 }
795 
groupedTraverse(TIntermNode & node)796 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
797 {
798     const bool emitParens = Parenthesize(node);
799 
800     if (emitParens)
801     {
802         mOut << "(";
803     }
804 
805     node.traverse(this);
806 
807     if (emitParens)
808     {
809         mOut << ")";
810     }
811 }
812 
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)813 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
814                                           const VarDecl &decl,
815                                           const TQualifier qualifier)
816 {
817     bool isInvariant = false;
818     switch (qualifier)
819     {
820         case TQualifier::EvqPosition:
821             isInvariant = decl.type().isInvariant();
822             [[fallthrough]];
823         case TQualifier::EvqFragCoord:
824             mOut << " [[position]]";
825             break;
826 
827         case TQualifier::EvqClipDistance:
828             mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
829             break;
830 
831         case TQualifier::EvqPointSize:
832             mOut << " [[point_size]]";
833             break;
834 
835         case TQualifier::EvqVertexID:
836             if (evdConfig.isMainParameter)
837             {
838                 mOut << " [[vertex_id]]";
839             }
840             break;
841 
842         case TQualifier::EvqPointCoord:
843             if (evdConfig.isMainParameter)
844             {
845                 mOut << " [[point_coord]]";
846             }
847             break;
848 
849         case TQualifier::EvqFrontFacing:
850             if (evdConfig.isMainParameter)
851             {
852                 mOut << " [[front_facing]]";
853             }
854             break;
855 
856         case TQualifier::EvqSampleID:
857             if (evdConfig.isMainParameter)
858             {
859                 mOut << " [[sample_id]]";
860             }
861             break;
862 
863         case TQualifier::EvqSampleMaskIn:
864             if (evdConfig.isMainParameter)
865             {
866                 mOut << " [[sample_mask]]";
867             }
868             break;
869 
870         default:
871             break;
872     }
873 
874     if (isInvariant)
875     {
876         mOut << " [[invariant]]";
877 
878         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
879         reflection->hasInvariance             = true;
880     }
881 }
882 
emitLoopBody(TIntermBlock * bodyNode)883 void GenMetalTraverser::emitLoopBody(TIntermBlock *bodyNode)
884 {
885     if (mInjectAsmStatementIntoLoopBodies)
886     {
887         emitOpenBrace();
888 
889         emitIndentation();
890         mOut << "__asm__(\"\");\n";
891     }
892 
893     bodyNode->traverse(this);
894 
895     if (mInjectAsmStatementIntoLoopBodies)
896     {
897         emitCloseBrace();
898     }
899 }
900 
EmitName(Sink & out,const Name & name)901 static void EmitName(Sink &out, const Name &name)
902 {
903 #if defined(ANGLE_ENABLE_ASSERTS)
904     DebugSink::EscapedSink escapedOut(out.escape());
905 #else
906     TInfoSinkBase &escapedOut = out;
907 #endif
908     name.emit(escapedOut);
909 }
910 
emitNameOf(const TField & object)911 void GenMetalTraverser::emitNameOf(const TField &object)
912 {
913     EmitName(mOut, Name(object));
914 }
915 
emitNameOf(const TSymbol & object)916 void GenMetalTraverser::emitNameOf(const TSymbol &object)
917 {
918     auto it = mRenamedSymbols.find(&object);
919     if (it == mRenamedSymbols.end())
920     {
921         EmitName(mOut, Name(object));
922     }
923     else
924     {
925         EmitName(mOut, it->second);
926     }
927 }
928 
emitNameOf(const VarDecl & object)929 void GenMetalTraverser::emitNameOf(const VarDecl &object)
930 {
931     if (object.isField())
932     {
933         emitNameOf(object.field());
934     }
935     else
936     {
937         emitNameOf(object.variable());
938     }
939 }
940 
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)941 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
942 {
943     const TBasicType basicType = type.getBasicType();
944 
945     switch (basicType)
946     {
947         case TBasicType::EbtVoid:
948         case TBasicType::EbtBool:
949         case TBasicType::EbtFloat:
950         case TBasicType::EbtInt:
951         {
952             mOut << type.getBasicString();
953             break;
954         }
955         case TBasicType::EbtUInt:
956         {
957             if (type.isScalar())
958             {
959                 mOut << "uint32_t";
960             }
961             else
962             {
963                 mOut << type.getBasicString();
964             }
965         }
966         break;
967 
968         case TBasicType::EbtStruct:
969         {
970             const TStructure &structure = *type.getStruct();
971             emitNameOf(structure);
972         }
973         break;
974 
975         case TBasicType::EbtInterfaceBlock:
976         {
977             const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
978             emitNameOf(interfaceBlock);
979         }
980         break;
981 
982         default:
983         {
984             if (IsSampler(basicType))
985             {
986                 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
987                 {
988                     EmitName(mOut, GetTextureTypeName(basicType));
989                 }
990                 else
991                 {
992                     const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
993                     emitNameOf(env);
994                 }
995             }
996             else if (IsImage(basicType))
997             {
998                 mOut << "metal::texture2d<";
999                 switch (type.getBasicType())
1000                 {
1001                     case EbtImage2D:
1002                         mOut << "float";
1003                         break;
1004                     case EbtIImage2D:
1005                         mOut << "int";
1006                         break;
1007                     case EbtUImage2D:
1008                         mOut << "uint";
1009                         break;
1010                     default:
1011                         UNIMPLEMENTED();
1012                         break;
1013                 }
1014                 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
1015                 {
1016                     UNIMPLEMENTED();
1017                 }
1018                 mOut << ", metal::access::read_write>";
1019             }
1020             else
1021             {
1022                 UNIMPLEMENTED();
1023             }
1024         }
1025     }
1026 }
1027 
emitType(const TType & type,const EmitTypeConfig & etConfig)1028 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
1029 {
1030     const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
1031     if (etConfig.evdConfig)
1032     {
1033         const auto &evdConfig = *etConfig.evdConfig;
1034         if (isUBO)
1035         {
1036             if (type.isArray())
1037             {
1038                 mOut << "metal::array<";
1039             }
1040         }
1041         if (evdConfig.isPointer)
1042         {
1043             mOut << toString(*evdConfig.isPointer);
1044             mOut << " ";
1045         }
1046         else if (evdConfig.isReference)
1047         {
1048             mOut << toString(*evdConfig.isReference);
1049             mOut << " ";
1050         }
1051     }
1052 
1053     if (!isUBO)
1054     {
1055         if (type.isArray())
1056         {
1057             mOut << "metal::array<";
1058         }
1059     }
1060 
1061     if (type.isInterpolant())
1062     {
1063         mOut << "metal::interpolant<";
1064     }
1065 
1066     if (type.isVector() || type.isMatrix())
1067     {
1068         mOut << "metal::";
1069     }
1070 
1071     if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
1072     {
1073         mOut << "packed_";
1074     }
1075 
1076     emitBareTypeName(type, etConfig);
1077 
1078     if (type.isVector())
1079     {
1080         mOut << static_cast<uint32_t>(type.getNominalSize());
1081     }
1082     else if (type.isMatrix())
1083     {
1084         mOut << static_cast<uint32_t>(type.getCols()) << "x"
1085              << static_cast<uint32_t>(type.getRows());
1086     }
1087 
1088     if (type.isInterpolant())
1089     {
1090         mOut << ", metal::interpolation::";
1091         switch (type.getQualifier())
1092         {
1093             case EvqNoPerspectiveIn:
1094             case EvqNoPerspectiveCentroidIn:
1095             case EvqNoPerspectiveSampleIn:
1096                 mOut << "no_";
1097                 break;
1098             default:
1099                 break;
1100         }
1101         mOut << "perspective>";
1102     }
1103 
1104     if (!isUBO)
1105     {
1106         if (type.isArray())
1107         {
1108             for (auto size : type.getArraySizes())
1109             {
1110                 mOut << ", " << size;
1111             }
1112             mOut << ">";
1113         }
1114     }
1115 
1116     if (etConfig.evdConfig)
1117     {
1118         const auto &evdConfig = *etConfig.evdConfig;
1119         if (evdConfig.isPointer)
1120         {
1121             mOut << " *";
1122         }
1123         else if (evdConfig.isReference)
1124         {
1125             mOut << " &";
1126         }
1127         if (isUBO)
1128         {
1129             if (type.isArray())
1130             {
1131                 for (auto size : type.getArraySizes())
1132                 {
1133                     mOut << ", " << size;
1134                 }
1135                 mOut << ">";
1136             }
1137         }
1138     }
1139 }
1140 
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1141 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1142                                              const TStructure &parent,
1143                                              FieldAnnotationIndices &annotationIndices)
1144 {
1145     const TType &type      = *field.type();
1146     const TBasicType basic = type.getBasicType();
1147 
1148     EmitVariableDeclarationConfig evdConfig;
1149     evdConfig.emitPostQualifier      = true;
1150     evdConfig.disableStructSpecifier = true;
1151     evdConfig.isPacked               = mSymbolEnv.isPacked(field);
1152     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1153     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1154     evdConfig.isReference            = mSymbolEnv.isReference(field);
1155     emitVariableDeclaration(VarDecl(field), evdConfig);
1156 
1157     const TQualifier qual = type.getQualifier();
1158     switch (qual)
1159     {
1160         case TQualifier::EvqFlatIn:
1161             if (mPipelineStructs.fragmentIn.external == &parent)
1162             {
1163                 mOut << " [[flat]]";
1164                 TranslatorMetalReflection *reflection =
1165                     mtl::getTranslatorMetalReflection(&mCompiler);
1166                 reflection->hasFlatInput = true;
1167             }
1168             break;
1169 
1170         case TQualifier::EvqNoPerspectiveIn:
1171             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1172             {
1173                 mOut << " [[center_no_perspective]]";
1174             }
1175             break;
1176 
1177         case TQualifier::EvqCentroidIn:
1178             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1179             {
1180                 mOut << " [[centroid_perspective]]";
1181             }
1182             break;
1183 
1184         case TQualifier::EvqSampleIn:
1185             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1186             {
1187                 mOut << " [[sample_perspective]]";
1188             }
1189             break;
1190 
1191         case TQualifier::EvqNoPerspectiveCentroidIn:
1192             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1193             {
1194                 mOut << " [[centroid_no_perspective]]";
1195             }
1196             break;
1197 
1198         case TQualifier::EvqNoPerspectiveSampleIn:
1199             if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1200             {
1201                 mOut << " [[sample_no_perspective]]";
1202             }
1203             break;
1204 
1205         case TQualifier::EvqFragColor:
1206             mOut << " [[color(0)]]";
1207             break;
1208 
1209         case TQualifier::EvqSecondaryFragColorEXT:
1210         case TQualifier::EvqSecondaryFragDataEXT:
1211             mOut << " [[color(0), index(1)]]";
1212             break;
1213 
1214         case TQualifier::EvqFragmentOut:
1215         case TQualifier::EvqFragmentInOut:
1216         case TQualifier::EvqFragData:
1217             if (mPipelineStructs.fragmentOut.external == &parent ||
1218                 mPipelineStructs.fragmentOut.externalExtra == &parent)
1219             {
1220                 if ((type.isVector() &&
1221                      (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1222                       basic == TBasicType::EbtFloat)) ||
1223                     qual == EvqFragData)
1224                 {
1225                     // The OpenGL ES 3.0 spec says locations must be specified
1226                     // unless there is only a single output, in which case the
1227                     // location is 0. So, when we get to this point the shader
1228                     // will have been rejected if locations are not specified
1229                     // and there is more than one output.
1230                     const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1231                     if (layoutQualifier.locationsSpecified)
1232                     {
1233                         mOut << " [[color(" << layoutQualifier.location << ")";
1234                         ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
1235                         if (layoutQualifier.index == 1)
1236                         {
1237                             mOut << ", index(1)";
1238                         }
1239                     }
1240                     else if (qual == EvqFragData)
1241                     {
1242                         mOut << " [[color(" << annotationIndices.color++ << ")";
1243                     }
1244                     else
1245                     {
1246                         // Either the only output or EXT_blend_func_extended is used;
1247                         // actual assignment will happen in UpdateFragmentShaderOutputs.
1248                         mOut << " [[" << sh::kUnassignedFragmentOutputString;
1249                     }
1250                     if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
1251                     {
1252                         // Put fragment inouts in their own raster order group for better
1253                         // parallelism.
1254                         // NOTE: this is not required for the reads to be ordered and coherent.
1255                         // TODO(anglebug.com/40096838): Consider making raster order groups a PLS
1256                         // layout qualifier?
1257                         mOut << ", raster_order_group(0)";
1258                     }
1259                     mOut << "]]";
1260                 }
1261             }
1262             break;
1263 
1264         case TQualifier::EvqFragDepth:
1265             mOut << " [[depth(";
1266             switch (type.getLayoutQualifier().depth)
1267             {
1268                 case EdGreater:
1269                     mOut << "greater";
1270                     break;
1271                 case EdLess:
1272                     mOut << "less";
1273                     break;
1274                 default:
1275                     mOut << "any";
1276                     break;
1277             }
1278             mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
1279             break;
1280 
1281         case TQualifier::EvqSampleMask:
1282             if (field.symbolType() == SymbolType::AngleInternal)
1283             {
1284                 mOut << " [[sample_mask, function_constant("
1285                      << sh::mtl::kSampleMaskWriteEnabledConstName << ")]]";
1286             }
1287             break;
1288 
1289         default:
1290             break;
1291     }
1292 
1293     if (IsImage(type.getBasicType()))
1294     {
1295         if (type.getArraySizeProduct() != 1)
1296         {
1297             UNIMPLEMENTED();
1298         }
1299         mOut << " [[";
1300         if (type.getLayoutQualifier().rasterOrdered)
1301         {
1302             mOut << "raster_order_group(0), ";
1303         }
1304         mOut << "texture(" << mMainTextureIndex << ")]]";
1305         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1306         reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
1307                                         static_cast<int>(mMainTextureIndex));
1308         ++mMainTextureIndex;
1309     }
1310 }
1311 
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1312 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1313     const TCompiler &compiler,
1314     const PipelineScoped<TStructure> &structure)
1315 {
1316     ASSERT(structure.isTotallyFull());
1317 
1318     const auto &shaderVars     = compiler.getAttributes();
1319     const size_t shaderVarSize = shaderVars.size();
1320     size_t shaderVarIndex      = 0;
1321 
1322     const auto &externalFields = structure.external->fields();
1323     const size_t externalSize  = externalFields.size();
1324     size_t externalIndex       = 0;
1325 
1326     const auto &internalFields = structure.internal->fields();
1327     const size_t internalSize  = internalFields.size();
1328     size_t internalIndex       = 0;
1329 
1330     // Internal fields are never split. External fields are sometimes split.
1331     ASSERT(externalSize >= internalSize);
1332 
1333     // Structures do not contain any inactive fields.
1334     ASSERT(shaderVarSize >= internalSize);
1335 
1336     std::map<Name, size_t> externalNameToAttributeIndex;
1337     size_t attributeIndex = 0;
1338 
1339     while (internalIndex < internalSize)
1340     {
1341         const TField &internalField = *internalFields[internalIndex];
1342         const Name internalName     = Name(internalField);
1343         const TType &internalType   = *internalField.type();
1344         while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1345                internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1346         {
1347             // This case represents an inactive field.
1348 
1349             ++shaderVarIndex;
1350             ASSERT(shaderVarIndex < shaderVarSize);
1351 
1352             ++attributeIndex;  // TODO: Might need to increment more if shader var type is a matrix.
1353         }
1354 
1355         const size_t cols =
1356             (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
1357                 ? internalType.getCols()
1358                 : 1;
1359 
1360         for (size_t c = 0; c < cols; ++c)
1361         {
1362             const TField &externalField = *externalFields[externalIndex];
1363             const Name externalName     = Name(externalField);
1364 
1365             externalNameToAttributeIndex[externalName] = attributeIndex;
1366 
1367             ++externalIndex;
1368             ++attributeIndex;
1369         }
1370 
1371         ++shaderVarIndex;
1372         ++internalIndex;
1373     }
1374 
1375     ASSERT(shaderVarIndex <= shaderVarSize);
1376     ASSERT(externalIndex <= externalSize);  // less than if padding was introduced
1377     ASSERT(internalIndex == internalSize);
1378 
1379     return externalNameToAttributeIndex;
1380 }
1381 
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1382 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1383                                                  FieldAnnotationIndices &annotationIndices)
1384 {
1385     EmitVariableDeclarationConfig evdConfig;
1386     evdConfig.disableStructSpecifier = true;
1387     emitVariableDeclaration(VarDecl(field), evdConfig);
1388     mOut << sh::kUnassignedAttributeString;
1389 }
1390 
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1391 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1392                                                      FieldAnnotationIndices &annotationIndices)
1393 {
1394     EmitVariableDeclarationConfig evdConfig;
1395     evdConfig.disableStructSpecifier = true;
1396     evdConfig.isUBO                  = mSymbolEnv.isUBO(field);
1397     evdConfig.isPointer              = mSymbolEnv.isPointer(field);
1398     emitVariableDeclaration(VarDecl(field), evdConfig);
1399     mOut << "[[id(" << annotationIndices.attribute << ")]]";
1400 
1401     const TType &type   = *field.type();
1402     const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1403 
1404     TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1405     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1406     const TStructure *structure    = type.getStruct();
1407     const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1408     reflection->addUniformBufferBinding(
1409         originalName,
1410         {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1411 
1412     annotationIndices.attribute += arraySize;
1413 }
1414 
emitStructDeclaration(const TType & type)1415 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1416 {
1417     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1418     ASSERT(type.isStructSpecifier());
1419 
1420     mOut << "struct ";
1421     emitBareTypeName(type, {});
1422 
1423     mOut << "\n";
1424     emitOpenBrace();
1425 
1426     const TStructure &structure = *type.getStruct();
1427     std::map<Name, size_t> fieldToAttributeIndex;
1428     const bool hasAttributeIndices      = mPipelineStructs.vertexIn.external == &structure;
1429     const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1430     const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1431 
1432     if (hasAttributeIndices)
1433     {
1434         // When attribute aliasing is supported, external attribute struct is filled post-link.
1435         if (mCompiler.supportsAttributeAliasing())
1436         {
1437             mtl::getTranslatorMetalReflection(&mCompiler)->hasAttributeAliasing = true;
1438             mOut << "@@Attrib-Bindings@@\n";
1439             emitCloseBrace();
1440             return;
1441         }
1442 
1443         fieldToAttributeIndex =
1444             BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1445     }
1446 
1447     FieldAnnotationIndices annotationIndices;
1448 
1449     for (const TField *field : structure.fields())
1450     {
1451         emitIndentation();
1452         if (hasAttributeIndices)
1453         {
1454             const auto it = fieldToAttributeIndex.find(Name(*field));
1455             if (it == fieldToAttributeIndex.end())
1456             {
1457                 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1458                 ASSERT(field->name().beginsWith("_"));
1459                 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1460                 emitFieldDeclaration(*field, structure, annotationIndices);
1461             }
1462             else
1463             {
1464                 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1465                        !field->name().beginsWith("_") ||
1466                        !angle::EndsWith(field->name().data(), "_pad"));
1467                 if (!reclaimUnusedAttributeIndices)
1468                 {
1469                     annotationIndices.attribute = it->second;
1470                 }
1471                 emitAttributeDeclaration(*field, annotationIndices);
1472             }
1473         }
1474         else if (hasUniformBufferIndicies)
1475         {
1476             emitUniformBufferDeclaration(*field, annotationIndices);
1477         }
1478         else
1479         {
1480             emitFieldDeclaration(*field, structure, annotationIndices);
1481         }
1482         mOut << ";\n";
1483     }
1484 
1485     emitCloseBrace();
1486 }
1487 
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1488 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1489     const VarDecl &decl,
1490     const EmitVariableDeclarationConfig &evdConfig)
1491 {
1492     EmitTypeConfig etConfig;
1493     etConfig.evdConfig = &evdConfig;
1494 
1495     const TType &type = decl.type();
1496     if (type.getQualifier() == TQualifier::EvqClipDistance)
1497     {
1498         // Clip distance output uses float[n] type instead of metal::array.
1499         // The element count is emitted after the post qualifier.
1500         ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1501         mOut << "float";
1502     }
1503     else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1504     {
1505         // Metal's [[sample_id]] must be unsigned
1506         ASSERT(type.getBasicType() == TBasicType::EbtInt);
1507         mOut << "uint32_t";
1508     }
1509     else
1510     {
1511         emitType(type, etConfig);
1512     }
1513     if (decl.symbolType() != SymbolType::Empty)
1514     {
1515         mOut << " ";
1516         emitNameOf(decl);
1517     }
1518 }
1519 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1520 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1521                                                 const EmitVariableDeclarationConfig &evdConfig)
1522 {
1523     const SymbolType symbolType = decl.symbolType();
1524     const TType &type           = decl.type();
1525     const TBasicType basicType  = type.getBasicType();
1526 
1527     switch (basicType)
1528     {
1529         case TBasicType::EbtStruct:
1530         {
1531             if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1532             {
1533                 // It's invalid to declare a struct inside a function argument. When emitting a
1534                 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1535                 ASSERT(!evdConfig.isParameter);
1536                 emitStructDeclaration(type);
1537                 if (symbolType != SymbolType::Empty)
1538                 {
1539                     mOut << " ";
1540                     emitNameOf(decl);
1541                 }
1542             }
1543             else
1544             {
1545                 emitOrdinaryVariableDeclaration(decl, evdConfig);
1546             }
1547         }
1548         break;
1549 
1550         default:
1551         {
1552             ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1553             emitOrdinaryVariableDeclaration(decl, evdConfig);
1554         }
1555     }
1556 
1557     if (evdConfig.emitPostQualifier)
1558     {
1559         emitPostQualifier(evdConfig, decl, type.getQualifier());
1560     }
1561 }
1562 
visitSymbol(TIntermSymbol * symbolNode)1563 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1564 {
1565     const TVariable &var = symbolNode->variable();
1566     const TType &type    = var.getType();
1567     ASSERT(var.symbolType() != SymbolType::Empty);
1568 
1569     if (type.getBasicType() == TBasicType::EbtVoid)
1570     {
1571         mOut << "/*";
1572         emitNameOf(var);
1573         mOut << "*/";
1574     }
1575     else
1576     {
1577         emitNameOf(var);
1578     }
1579 }
1580 
emitSingleConstant(const TConstantUnion * const constUnion)1581 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1582 {
1583     switch (constUnion->getType())
1584     {
1585         case TBasicType::EbtBool:
1586         {
1587             mOut << (constUnion->getBConst() ? "true" : "false");
1588         }
1589         break;
1590 
1591         case TBasicType::EbtFloat:
1592         {
1593             float value = constUnion->getFConst();
1594             if (std::isnan(value))
1595             {
1596                 mOut << "NAN";
1597             }
1598             else if (std::isinf(value))
1599             {
1600                 if (value < 0)
1601                 {
1602                     mOut << "-";
1603                 }
1604                 mOut << "INFINITY";
1605             }
1606             else
1607             {
1608                 mOut << value << "f";
1609             }
1610         }
1611         break;
1612 
1613         case TBasicType::EbtInt:
1614         {
1615             mOut << constUnion->getIConst();
1616         }
1617         break;
1618 
1619         case TBasicType::EbtUInt:
1620         {
1621             mOut << constUnion->getUConst() << "u";
1622         }
1623         break;
1624 
1625         default:
1626         {
1627             UNIMPLEMENTED();
1628         }
1629     }
1630 }
1631 
// Emits |size| consecutive scalar constants starting at |constUnion|, separated by ", ".
// Returns a pointer to the first constant after the emitted ones.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const end = constUnion + size;
    const TConstantUnion *cursor    = constUnion;

    for (; cursor != end; ++cursor)
    {
        if (cursor != constUnion)
        {
            mOut << ", ";
        }
        emitSingleConstant(cursor);
    }
    return cursor;
}
1647 }
1648 
// Emits the constant value of |type| whose scalars start at |constUnionBegin|:
//  - Structs are written as Type{field0, field1, ...}, recursing into each field's type.
//  - Values with more than one scalar are written as Type(c0, c1, ...).
//  - Single scalars are written bare.
// Returns a pointer just past the last scalar consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *cursor = constUnionBegin;
    const TStructure *structure  = type.getStruct();

    if (structure == nullptr)
    {
        const size_t scalarCount    = type.getObjectSize();
        const bool needsConstructor = scalarCount > 1;
        if (needsConstructor)
        {
            emitType(type, EmitTypeConfig{nullptr});
            mOut << "(";
        }
        cursor = emitConstantUnionArray(cursor, scalarCount);
        if (needsConstructor)
        {
            mOut << ")";
        }
        return cursor;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";
    bool isFirstField = true;
    for (const TField *field : structure->fields())
    {
        if (!isFirstField)
        {
            mOut << ", ";
        }
        cursor       = emitConstantUnion(*field->type(), cursor);
        isFirstField = false;
    }
    mOut << "}";
    return cursor;
}
1688 }
1689 
visitConstantUnion(TIntermConstantUnion * constValueNode)1690 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1691 {
1692     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1693 }
1694 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1695 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1696 {
1697     groupedTraverse(*swizzleNode->getOperand());
1698     mOut << ".";
1699 
1700     {
1701 #if defined(ANGLE_ENABLE_ASSERTS)
1702         DebugSink::EscapedSink escapedOut(mOut.escape());
1703         TInfoSinkBase &out = escapedOut.get();
1704 #else
1705         TInfoSinkBase &out = mOut;
1706 #endif
1707         swizzleNode->writeOffsetsAsXYZW(&out);
1708     }
1709 
1710     return false;
1711 }
1712 
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1713 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1714                                                 const TConstantUnion &index)
1715 {
1716     ASSERT(index.getType() == TBasicType::EbtInt);
1717 
1718     const TFieldList &fieldList = fieldListCollection.fields();
1719     const int indexVal          = index.getIConst();
1720     const TField &field         = *fieldList[indexVal];
1721 
1722     return field;
1723 }
1724 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1725 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1726                                                 TIntermTyped &indexNode)
1727 {
1728     const TType &fieldsType = fieldsNode.getType();
1729 
1730     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1731     if (fieldListCollection == nullptr)
1732     {
1733         fieldListCollection = fieldsType.getInterfaceBlock();
1734     }
1735     ASSERT(fieldListCollection);
1736 
1737     const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1738     ASSERT(indexNode_);
1739     const TConstantUnion &index = *indexNode_->getConstantValue();
1740 
1741     return getDirectField(*fieldListCollection, index);
1742 }
1743 
visitBinary(Visit,TIntermBinary * binaryNode)1744 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1745 {
1746     const TOperator op      = binaryNode->getOp();
1747     TIntermTyped &leftNode  = *binaryNode->getLeft();
1748     TIntermTyped &rightNode = *binaryNode->getRight();
1749 
1750     switch (op)
1751     {
1752         case TOperator::EOpIndexDirectStruct:
1753         case TOperator::EOpIndexDirectInterfaceBlock:
1754         {
1755             const TField &field = getDirectField(leftNode, rightNode);
1756             if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1757             {
1758                 emitOpeningPointerParen();
1759             }
1760             groupedTraverse(leftNode);
1761             if (!mSymbolEnv.isPointer(field))
1762             {
1763                 emitClosingPointerParen();
1764             }
1765             mOut << ".";
1766             emitNameOf(field);
1767         }
1768         break;
1769 
1770         case TOperator::EOpIndexDirect:
1771         case TOperator::EOpIndexIndirect:
1772         {
1773             TType leftType = leftNode.getType();
1774             groupedTraverse(leftNode);
1775             mOut << "[";
1776             const TConstantUnion *constIndex = rightNode.getConstantValue();
1777             // TODO(anglebug.com/42266914): Convert type and bound checks to
1778             // assertions after AST validation is enabled for MSL translation.
1779             if (!leftType.isUnsizedArray() && constIndex != nullptr &&
1780                 constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
1781                 constIndex->getIConst() < static_cast<int>(leftType.isArray()
1782                                                                ? leftType.getOutermostArraySize()
1783                                                                : leftType.getNominalSize()))
1784             {
1785                 emitSingleConstant(constIndex);
1786             }
1787             else
1788             {
1789                 mOut << "ANGLE_int_clamp(";
1790                 groupedTraverse(rightNode);
1791                 mOut << ", 0, ";
1792                 if (leftType.isUnsizedArray())
1793                 {
1794                     groupedTraverse(leftNode);
1795                     mOut << ".size()";
1796                 }
1797                 else
1798                 {
1799                     uint32_t maxSize;
1800                     if (leftType.isArray())
1801                     {
1802                         maxSize = leftType.getOutermostArraySize() - 1;
1803                     }
1804                     else
1805                     {
1806                         maxSize = leftType.getNominalSize() - 1;
1807                     }
1808                     mOut << maxSize;
1809                 }
1810                 mOut << ")";
1811             }
1812             mOut << "]";
1813         }
1814         break;
1815 
1816         default:
1817         {
1818             const TType &resultType = binaryNode->getType();
1819             const TType &leftType   = leftNode.getType();
1820             const TType &rightType  = rightNode.getType();
1821 
1822             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1823             {
1824                 groupedTraverse(leftNode);
1825                 if (op != TOperator::EOpComma)
1826                 {
1827                     mOut << " ";
1828                 }
1829                 else
1830                 {
1831                     emitClosingPointerParen();
1832                 }
1833                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1834                 groupedTraverse(rightNode);
1835             }
1836             else
1837             {
1838                 emitClosingPointerParen();
1839                 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1840                 leftNode.traverse(this);
1841                 mOut << ", ";
1842                 rightNode.traverse(this);
1843                 mOut << ")";
1844             }
1845         }
1846     }
1847 
1848     return false;
1849 }
1850 
IsPostfix(TOperator op)1851 static bool IsPostfix(TOperator op)
1852 {
1853     switch (op)
1854     {
1855         case TOperator::EOpPostIncrement:
1856         case TOperator::EOpPostDecrement:
1857             return true;
1858 
1859         default:
1860             return false;
1861     }
1862 }
1863 
visitUnary(Visit,TIntermUnary * unaryNode)1864 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1865 {
1866     const TOperator op      = unaryNode->getOp();
1867     const TType &resultType = unaryNode->getType();
1868 
1869     TIntermTyped &arg    = *unaryNode->getOperand();
1870     const TType &argType = arg.getType();
1871 
1872     if (op == TOperator::EOpIsnan || op == TOperator::EOpIsinf)
1873     {
1874         mtl::getTranslatorMetalReflection(&mCompiler)->hasIsnanOrIsinf = true;
1875     }
1876 
1877     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1878 
1879     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1880     {
1881         const bool postfix = IsPostfix(op);
1882         if (!postfix)
1883         {
1884             mOut << name;
1885         }
1886         groupedTraverse(arg);
1887         if (postfix)
1888         {
1889             mOut << name;
1890         }
1891     }
1892     else
1893     {
1894         mOut << name << "(";
1895         arg.traverse(this);
1896         mOut << ")";
1897     }
1898 
1899     return false;
1900 }
1901 
visitTernary(Visit,TIntermTernary * conditionalNode)1902 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1903 {
1904     groupedTraverse(*conditionalNode->getCondition());
1905     mOut << " ? ";
1906     groupedTraverse(*conditionalNode->getTrueExpression());
1907     mOut << " : ";
1908     groupedTraverse(*conditionalNode->getFalseExpression());
1909 
1910     return false;
1911 }
1912 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1913 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1914 {
1915     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1916     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1917     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1918 
1919     emitIndentation();
1920     mOut << "if (";
1921     condNode.traverse(this);
1922     mOut << ")";
1923 
1924     if (thenNode)
1925     {
1926         mOut << "\n";
1927         thenNode->traverse(this);
1928     }
1929     else
1930     {
1931         mOut << " {}";
1932     }
1933 
1934     if (elseNode)
1935     {
1936         mOut << "\n";
1937         emitIndentation();
1938         mOut << "else\n";
1939         elseNode->traverse(this);
1940     }
1941     else
1942     {
1943         // Always emit "else" even when empty block to avoid nested if-stmt issues.
1944         mOut << " else {}";
1945     }
1946 
1947     return false;
1948 }
1949 
visitSwitch(Visit,TIntermSwitch * switchNode)1950 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1951 {
1952     emitIndentation();
1953     mOut << "switch (";
1954     switchNode->getInit()->traverse(this);
1955     mOut << ")\n";
1956 
1957     ASSERT(!mParentIsSwitch);
1958     mParentIsSwitch = true;
1959     switchNode->getStatementList()->traverse(this);
1960     mParentIsSwitch = false;
1961 
1962     return false;
1963 }
1964 
visitCase(Visit,TIntermCase * caseNode)1965 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1966 {
1967     emitIndentation();
1968 
1969     if (caseNode->hasCondition())
1970     {
1971         TIntermTyped *condExpr = caseNode->getCondition();
1972         mOut << "case ";
1973         condExpr->traverse(this);
1974         mOut << ":";
1975     }
1976     else
1977     {
1978         mOut << "default:\n";
1979     }
1980 
1981     return false;
1982 }
1983 
emitFunctionSignature(const TFunction & func)1984 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1985 {
1986     const bool isMain = func.isMain();
1987 
1988     emitFunctionReturn(func);
1989 
1990     mOut << " ";
1991     emitNameOf(func);
1992     if (isMain)
1993     {
1994         mOut << "0";
1995     }
1996     mOut << "(";
1997 
1998     bool emitComma          = false;
1999     const size_t paramCount = func.getParamCount();
2000     for (size_t i = 0; i < paramCount; ++i)
2001     {
2002         if (emitComma)
2003         {
2004             mOut << ", ";
2005         }
2006         emitComma = true;
2007 
2008         const TVariable &param = *func.getParam(i);
2009         emitFunctionParameter(func, param);
2010     }
2011 
2012     if (isTraversingVertexMain)
2013     {
2014         mOut << " @@XFB-Bindings@@ ";
2015     }
2016 
2017     mOut << ")";
2018 }
2019 
emitFunctionReturn(const TFunction & func)2020 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
2021 {
2022     const bool isMain       = func.isMain();
2023     bool isVertexMain       = false;
2024     const TType &returnType = func.getReturnType();
2025     if (isMain)
2026     {
2027         const TStructure *structure = returnType.getStruct();
2028         if (mPipelineStructs.fragmentOut.matches(*structure))
2029         {
2030             if (mCompiler.isEarlyFragmentTestsSpecified())
2031             {
2032                 mOut << "[[early_fragment_tests]]\n";
2033             }
2034             mOut << "fragment ";
2035         }
2036         else if (mPipelineStructs.vertexOut.matches(*structure))
2037         {
2038             ASSERT(structure != nullptr);
2039             mOut << "vertex __VERTEX_OUT(";
2040             isVertexMain = true;
2041         }
2042         else
2043         {
2044             UNIMPLEMENTED();
2045         }
2046     }
2047     emitType(returnType, EmitTypeConfig());
2048     if (isVertexMain)
2049         mOut << ") ";
2050 }
2051 
emitFunctionParameter(const TFunction & func,const TVariable & param)2052 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
2053 {
2054     const bool isMain = func.isMain();
2055 
2056     const TType &type           = param.getType();
2057     const TStructure *structure = type.getStruct();
2058 
2059     EmitVariableDeclarationConfig evdConfig;
2060     evdConfig.isParameter            = true;
2061     evdConfig.disableStructSpecifier = true;  // It's invalid to declare a struct in a function arg.
2062     evdConfig.isMainParameter        = isMain;
2063     evdConfig.emitPostQualifier      = isMain;
2064     evdConfig.isUBO                  = mSymbolEnv.isUBO(param);
2065     evdConfig.isPointer              = mSymbolEnv.isPointer(param);
2066     evdConfig.isReference            = mSymbolEnv.isReference(param);
2067     emitVariableDeclaration(VarDecl(param), evdConfig);
2068 
2069     if (isMain)
2070     {
2071         TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2072         if (structure)
2073         {
2074             if (mPipelineStructs.fragmentIn.matches(*structure) ||
2075                 mPipelineStructs.vertexIn.matches(*structure))
2076             {
2077                 mOut << " [[stage_in]]";
2078             }
2079             else if (mPipelineStructs.angleUniforms.matches(*structure))
2080             {
2081                 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2082             }
2083             else if (mPipelineStructs.uniformBuffers.matches(*structure))
2084             {
2085                 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2086                 reflection->hasUBOs = true;
2087             }
2088             else if (mPipelineStructs.userUniforms.matches(*structure))
2089             {
2090                 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2091                 reflection->addUserUniformBufferBinding(param.name().data(),
2092                                                         mMainUniformBufferIndex);
2093                 mMainUniformBufferIndex += type.getArraySizeProduct();
2094             }
2095             else if (structure->name() == "metal::sampler")
2096             {
2097                 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2098                 const std::string originalName =
2099                     reflection->getOriginalName(param.uniqueId().get());
2100                 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2101                 mMainSamplerIndex += type.getArraySizeProduct();
2102             }
2103         }
2104         else if (IsSampler(type.getBasicType()))
2105         {
2106             mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2107             const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2108             reflection->addTextureBinding(originalName, mMainTextureIndex);
2109             mMainTextureIndex += type.getArraySizeProduct();
2110         }
2111         else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2112                                     Pipeline::Variant::Modified))
2113         {
2114             mOut << " [[instance_id]]";
2115         }
2116         else if (Name(param) == kBaseInstanceName)
2117         {
2118             mOut << " [[base_instance]]";
2119         }
2120     }
2121 }
2122 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2123 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2124 {
2125     const TFunction &func = *funcProtoNode->getFunction();
2126 
2127     emitIndentation();
2128     emitFunctionSignature(func);
2129 }
2130 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2131 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2132 {
2133     const TFunction &func = *funcDefNode->getFunction();
2134     TIntermBlock &body    = *funcDefNode->getBody();
2135     if (func.isMain())
2136     {
2137         const TType &returnType     = func.getReturnType();
2138         const TStructure *structure = returnType.getStruct();
2139         isTraversingVertexMain      = (mPipelineStructs.vertexOut.matches(*structure)) &&
2140                                  mCompiler.getShaderType() == GL_VERTEX_SHADER;
2141     }
2142     emitIndentation();
2143     emitFunctionSignature(func);
2144     mOut << "\n";
2145     body.traverse(this);
2146     if (isTraversingVertexMain)
2147     {
2148         isTraversingVertexMain = false;
2149     }
2150     return false;
2151 }
2152 
BuildFuncToName()2153 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2154 {
2155     FuncToName map;
2156 
2157     auto putAngle = [&](const char *nameStr) {
2158         const ImmutableString name(nameStr);
2159         ASSERT(map.find(name) == map.end());
2160         map[name] = Name(nameStr, SymbolType::AngleInternal);
2161     };
2162 
2163     putAngle("texelFetch");
2164     putAngle("texelFetchOffset");
2165     putAngle("texture");
2166     putAngle("texture2D");
2167     putAngle("texture2DGradEXT");
2168     putAngle("texture2DLod");
2169     putAngle("texture2DLodEXT");
2170     putAngle("texture2DProj");
2171     putAngle("texture2DProjGradEXT");
2172     putAngle("texture2DProjLod");
2173     putAngle("texture2DProjLodEXT");
2174     putAngle("texture3D");
2175     putAngle("texture3DLod");
2176     putAngle("texture3DProj");
2177     putAngle("texture3DProjLod");
2178     putAngle("textureCube");
2179     putAngle("textureCubeGradEXT");
2180     putAngle("textureCubeLod");
2181     putAngle("textureCubeLodEXT");
2182     putAngle("textureGrad");
2183     putAngle("textureGradOffset");
2184     putAngle("textureLod");
2185     putAngle("textureLodOffset");
2186     putAngle("textureOffset");
2187     putAngle("textureProj");
2188     putAngle("textureProjGrad");
2189     putAngle("textureProjGradOffset");
2190     putAngle("textureProjLod");
2191     putAngle("textureProjLodOffset");
2192     putAngle("textureProjOffset");
2193     putAngle("textureSize");
2194     putAngle("imageLoad");
2195     putAngle("imageStore");
2196     putAngle("memoryBarrierImage");
2197 
2198     return map;
2199 }
2200 
visitAggregate(Visit,TIntermAggregate * aggregateNode)2201 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2202 {
2203     const TIntermSequence &args = *aggregateNode->getSequence();
2204 
2205     auto emitArgList = [&](const char *open, const char *close) {
2206         mOut << open;
2207 
2208         bool emitComma = false;
2209         for (TIntermNode *arg : args)
2210         {
2211             if (emitComma)
2212             {
2213                 emitClosingPointerParen();
2214                 mOut << ", ";
2215             }
2216             emitComma = true;
2217             arg->traverse(this);
2218         }
2219 
2220         mOut << close;
2221     };
2222 
2223     const TType &retType = aggregateNode->getType();
2224 
2225     if (aggregateNode->isConstructor())
2226     {
2227         const bool isStandalone = getParentNode()->getAsBlock();
2228         if (isStandalone)
2229         {
2230             // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2231             // This can happen if given something like:
2232             //      int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2233             // So instead emit:
2234             //      (int(symbol));
2235             mOut << "(";
2236         }
2237 
2238         const EmitTypeConfig etConfig;
2239 
2240         if (retType.isArray())
2241         {
2242             emitType(retType, etConfig);
2243             emitArgList("{", "}");
2244         }
2245         else if (retType.getStruct())
2246         {
2247             emitType(retType, etConfig);
2248             emitArgList("{", "}");
2249         }
2250         else
2251         {
2252             emitType(retType, etConfig);
2253             emitArgList("(", ")");
2254         }
2255 
2256         if (isStandalone)
2257         {
2258             mOut << ")";
2259         }
2260 
2261         return false;
2262     }
2263     else
2264     {
2265         const TOperator op = aggregateNode->getOp();
2266         switch (op)
2267         {
2268             case TOperator::EOpCallFunctionInAST:
2269             case TOperator::EOpCallInternalRawFunction:
2270             {
2271                 const TFunction &func = *aggregateNode->getFunction();
2272                 emitNameOf(func);
2273                 //'@' symbol in name specifices a macro substitution marker.
2274                 if (!func.name().contains("@"))
2275                 {
2276                     emitArgList("(", ")");
2277                 }
2278                 else
2279                 {
2280                     mTemporarilyDisableSemicolon =
2281                         true;  // Disable semicolon for macro substitution.
2282                 }
2283                 return false;
2284             }
2285 
2286             default:
2287             {
2288                 auto getArgType = [&](size_t index) -> const TType * {
2289                     if (index < args.size())
2290                     {
2291                         TIntermTyped *arg = args[index]->getAsTyped();
2292                         ASSERT(arg);
2293                         return &arg->getType();
2294                     }
2295                     return nullptr;
2296                 };
2297 
2298                 const TType *argType0 = getArgType(0);
2299                 const TType *argType1 = getArgType(1);
2300                 const TType *argType2 = getArgType(2);
2301 
2302                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2303 
2304                 if (IsSymbolicOperator(op, retType, argType0, argType1))
2305                 {
2306                     switch (args.size())
2307                     {
2308                         case 1:
2309                         {
2310                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2311                             if (IsPostfix(op))
2312                             {
2313                                 mOut << opName;
2314                                 groupedTraverse(operandNode);
2315                             }
2316                             else
2317                             {
2318                                 groupedTraverse(operandNode);
2319                                 mOut << opName;
2320                             }
2321                             return false;
2322                         }
2323 
2324                         case 2:
2325                         {
2326                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
2327                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2328                             groupedTraverse(leftNode);
2329                             mOut << " " << opName << " ";
2330                             groupedTraverse(rightNode);
2331                             return false;
2332                         }
2333 
2334                         default:
2335                             UNREACHABLE();
2336                             return false;
2337                     }
2338                 }
2339                 else if (opName == nullptr)
2340                 {
2341                     const TFunction &func = *aggregateNode->getFunction();
2342                     auto it               = mFuncToName.find(func.name());
2343                     ASSERT(it != mFuncToName.end());
2344                     EmitName(mOut, it->second);
2345                     emitArgList("(", ")");
2346                     return false;
2347                 }
2348                 else
2349                 {
2350                     mOut << opName;
2351                     emitArgList("(", ")");
2352                     return false;
2353                 }
2354             }
2355         }
2356     }
2357 }
2358 
emitOpenBrace()2359 void GenMetalTraverser::emitOpenBrace()
2360 {
2361     ASSERT(mIndentLevel >= 0);
2362 
2363     emitIndentation();
2364     mOut << "{\n";
2365     ++mIndentLevel;
2366 }
2367 
emitCloseBrace()2368 void GenMetalTraverser::emitCloseBrace()
2369 {
2370     ASSERT(mIndentLevel >= 1);
2371 
2372     --mIndentLevel;
2373     emitIndentation();
2374     mOut << "}";
2375 }
2376 
RequiresSemicolonTerminator(TIntermNode & node)2377 static bool RequiresSemicolonTerminator(TIntermNode &node)
2378 {
2379     if (node.getAsBlock())
2380     {
2381         return false;
2382     }
2383     if (node.getAsLoopNode())
2384     {
2385         return false;
2386     }
2387     if (node.getAsSwitchNode())
2388     {
2389         return false;
2390     }
2391     if (node.getAsIfElseNode())
2392     {
2393         return false;
2394     }
2395     if (node.getAsFunctionDefinition())
2396     {
2397         return false;
2398     }
2399     if (node.getAsCaseNode())
2400     {
2401         return false;
2402     }
2403 
2404     return true;
2405 }
2406 
NewlinePad(TIntermNode & node)2407 static bool NewlinePad(TIntermNode &node)
2408 {
2409     if (node.getAsFunctionDefinition())
2410     {
2411         return true;
2412     }
2413     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2414     {
2415         ASSERT(declNode->getChildCount() == 1);
2416         TIntermNode &childNode = *declNode->getChildNode(0);
2417         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2418         {
2419             const TVariable &var = symbolNode->variable();
2420             return var.getType().isStructSpecifier();
2421         }
2422         return false;
2423     }
2424     return false;
2425 }
2426 
// Emits a block of statements.
// The root block is recognised by mIndentLevel == -1 and is emitted without braces. Nested
// blocks are wrapped in emitOpenBrace()/emitCloseBrace(). Each statement is indented and
// traversed. It is then followed by ';' (when RequiresSemicolonTerminator() says so and the
// statement did not set mTemporarilyDisableSemicolon) and a newline.
bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
{
    ASSERT(mIndentLevel >= -1);
    const bool isGlobalScope  = mIndentLevel == -1;
    // Consume the switch flag (set by visitSwitch for its statement list) so that blocks nested
    // inside this one see it cleared; it is restored on exit below.
    const bool parentIsSwitch = mParentIsSwitch;
    mParentIsSwitch           = false;

    if (isGlobalScope)
    {
        ++mIndentLevel;
    }
    else
    {
        emitOpenBrace();
        if (parentIsSwitch)
        {
            // Statements of a switch body get one extra level relative to the braces.
            ++mIndentLevel;
        }
    }

    TIntermNode *prevStmtNode = nullptr;

    const size_t stmtCount = blockNode->getChildCount();
    for (size_t i = 0; i < stmtCount; ++i)
    {
        TIntermNode &stmtNode = *blockNode->getChildNode(i);

        // At global scope, separate function definitions / struct declarations with a blank line.
        if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
        {
            mOut << "\n";
        }
        // Case labels are indented one level less than the statements around them.
        const bool isCase = stmtNode.getAsCaseNode();
        mIndentLevel -= isCase;
        emitIndentation();
        mIndentLevel += isCase;
        stmtNode.traverse(this);
        // visitAggregate sets mTemporarilyDisableSemicolon for '@' macro-marker calls; it only
        // applies to the statement just emitted, so it is reset unconditionally.
        if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
        {
            mOut << ";";
        }
        mTemporarilyDisableSemicolon = false;
        mOut << "\n";

        prevStmtNode = &stmtNode;
    }

    if (isGlobalScope)
    {
        ASSERT(mIndentLevel == 0);
        --mIndentLevel;
    }
    else
    {
        if (parentIsSwitch)
        {
            // Undo the extra switch-body level before closing the brace.
            ASSERT(mIndentLevel >= 1);
            --mIndentLevel;
        }
        emitCloseBrace();
        // Restore the caller's switch flag.
        mParentIsSwitch = parentIsSwitch;
    }

    return false;
}
2491 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)2492 bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
2493 {
2494     return false;
2495 }
2496 
visitDeclaration(Visit,TIntermDeclaration * declNode)2497 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2498 {
2499     ASSERT(declNode->getChildCount() == 1);
2500     TIntermNode &node = *declNode->getChildNode(0);
2501 
2502     EmitVariableDeclarationConfig evdConfig;
2503 
2504     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2505     {
2506         const TVariable &var = symbolNode->variable();
2507         emitVariableDeclaration(VarDecl(var), evdConfig);
2508         if (var.getType().isArray() && var.getType().getQualifier() == EvqTemporary)
2509         {
2510             // The translator frontend injects a loop-based init for user arrays when the source
2511             // shader is using ESSL 1.00. Some Metal drivers may fail to access elements of such
2512             // arrays at runtime depending on the array size. An empty literal initializer added
2513             // to the generated MSL bypasses the issue. The frontend may be further optimized to
2514             // skip the loop-based init when targeting MSL.
2515             mOut << "{}";
2516         }
2517     }
2518     else if (TIntermBinary *initNode = node.getAsBinaryNode())
2519     {
2520         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2521         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2522         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
2523         ASSERT(leftSymbolNode && valueNode);
2524 
2525         if (getRootNode() == getParentBlock())
2526         {
2527             // DeferGlobalInitializers should have turned non-const global initializers into
2528             // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2529             // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2530             // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2531             // DeferGlobalInitializers.
2532             mOut << "constant ";
2533         }
2534 
2535         const TVariable &var = leftSymbolNode->variable();
2536         const Name varName(var);
2537 
2538         if (ExpressionContainsName(varName, *valueNode))
2539         {
2540             mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2541         }
2542 
2543         emitVariableDeclaration(VarDecl(var), evdConfig);
2544         mOut << " = ";
2545         groupedTraverse(*valueNode);
2546     }
2547     else
2548     {
2549         UNREACHABLE();
2550     }
2551 
2552     return false;
2553 }
2554 
visitLoop(Visit,TIntermLoop * loopNode)2555 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2556 {
2557     const TLoopType loopType = loopNode->getType();
2558 
2559     switch (loopType)
2560     {
2561         case TLoopType::ELoopFor:
2562             return visitForLoop(loopNode);
2563         case TLoopType::ELoopWhile:
2564             return visitWhileLoop(loopNode);
2565         case TLoopType::ELoopDoWhile:
2566             return visitDoWhileLoop(loopNode);
2567     }
2568 }
2569 
visitForLoop(TIntermLoop * loopNode)2570 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2571 {
2572     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2573 
2574     TIntermNode *initNode  = loopNode->getInit();
2575     TIntermTyped *condNode = loopNode->getCondition();
2576     TIntermTyped *exprNode = loopNode->getExpression();
2577 
2578     mOut << "for (";
2579 
2580     if (initNode)
2581     {
2582         initNode->traverse(this);
2583     }
2584     else
2585     {
2586         mOut << " ";
2587     }
2588 
2589     mOut << "; ";
2590 
2591     if (condNode)
2592     {
2593         condNode->traverse(this);
2594     }
2595 
2596     mOut << "; ";
2597 
2598     if (exprNode)
2599     {
2600         exprNode->traverse(this);
2601     }
2602 
2603     mOut << ")\n";
2604 
2605     emitLoopBody(loopNode->getBody());
2606 
2607     return false;
2608 }
2609 
visitWhileLoop(TIntermLoop * loopNode)2610 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2611 {
2612     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2613 
2614     TIntermNode *initNode  = loopNode->getInit();
2615     TIntermTyped *condNode = loopNode->getCondition();
2616     TIntermTyped *exprNode = loopNode->getExpression();
2617     ASSERT(condNode);
2618     ASSERT(!initNode && !exprNode);
2619 
2620     emitIndentation();
2621     mOut << "while (";
2622     condNode->traverse(this);
2623     mOut << ")\n";
2624     emitLoopBody(loopNode->getBody());
2625 
2626     return false;
2627 }
2628 
visitDoWhileLoop(TIntermLoop * loopNode)2629 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2630 {
2631     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2632 
2633     TIntermNode *initNode  = loopNode->getInit();
2634     TIntermTyped *condNode = loopNode->getCondition();
2635     TIntermTyped *exprNode = loopNode->getExpression();
2636     ASSERT(condNode);
2637     ASSERT(!initNode && !exprNode);
2638 
2639     emitIndentation();
2640     mOut << "do\n";
2641     emitLoopBody(loopNode->getBody());
2642     mOut << "\n";
2643     emitIndentation();
2644     mOut << "while (";
2645     condNode->traverse(this);
2646     mOut << ");";
2647 
2648     return false;
2649 }
2650 
visitBranch(Visit,TIntermBranch * branchNode)2651 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2652 {
2653     const TOperator flowOp = branchNode->getFlowOp();
2654     TIntermTyped *exprNode = branchNode->getExpression();
2655 
2656     emitIndentation();
2657 
2658     switch (flowOp)
2659     {
2660         case TOperator::EOpKill:
2661         {
2662             ASSERT(exprNode == nullptr);
2663             mOut << "metal::discard_fragment()";
2664         }
2665         break;
2666 
2667         case TOperator::EOpReturn:
2668         {
2669             if (isTraversingVertexMain)
2670             {
2671                 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2672                 emitIndentation();
2673                 mOut << "return;\n";
2674                 emitIndentation();
2675                 mOut << "#else\n";
2676                 emitIndentation();
2677             }
2678             mOut << "return";
2679             if (exprNode)
2680             {
2681                 mOut << " ";
2682                 exprNode->traverse(this);
2683                 mOut << ";";
2684             }
2685             if (isTraversingVertexMain)
2686             {
2687                 mOut << "\n";
2688                 emitIndentation();
2689                 mOut << "#endif\n";
2690                 mTemporarilyDisableSemicolon = true;
2691             }
2692         }
2693         break;
2694 
2695         case TOperator::EOpBreak:
2696         {
2697             ASSERT(exprNode == nullptr);
2698             mOut << "break";
2699         }
2700         break;
2701 
2702         case TOperator::EOpContinue:
2703         {
2704             ASSERT(exprNode == nullptr);
2705             mOut << "continue";
2706         }
2707         break;
2708 
2709         default:
2710         {
2711             UNREACHABLE();
2712         }
2713     }
2714 
2715     return false;
2716 }
2717 
2718 static size_t emitMetalCallCount = 0;
2719 
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2720 bool sh::EmitMetal(TCompiler &compiler,
2721                    TIntermBlock &root,
2722                    IdGen &idGen,
2723                    const PipelineStructs &pipelineStructs,
2724                    SymbolEnv &symbolEnv,
2725                    const ProgramPreludeConfig &ppc,
2726                    const ShCompileOptions &compileOptions)
2727 {
2728     TInfoSinkBase &out = compiler.getInfoSink().obj;
2729 
2730     {
2731         ++emitMetalCallCount;
2732         std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2733         if (!filenameProto.empty())
2734         {
2735             if (filenameProto != "/dev/null")
2736             {
2737                 auto tryOpen = [&](char const *ext) {
2738                     auto filename = filenameProto;
2739                     filename += std::to_string(emitMetalCallCount);
2740                     filename += ".";
2741                     filename += ext;
2742                     return fopen(filename.c_str(), "rb");
2743                 };
2744                 FILE *file = tryOpen("metal");
2745                 if (!file)
2746                 {
2747                     file = tryOpen("cpp");
2748                 }
2749                 ASSERT(file);
2750 
2751                 fseek(file, 0, SEEK_END);
2752                 size_t fileSize = ftell(file);
2753                 fseek(file, 0, SEEK_SET);
2754 
2755                 std::vector<char> buff;
2756                 buff.resize(fileSize + 1);
2757                 fread(buff.data(), fileSize, 1, file);
2758                 buff.back() = '\0';
2759 
2760                 fclose(file);
2761 
2762                 out << buff.data();
2763             }
2764 
2765             return true;
2766         }
2767     }
2768 
2769     out << "\n\n";
2770 
2771     if (!EmitProgramPrelude(root, out, ppc))
2772     {
2773         return false;
2774     }
2775 
2776     {
2777 #if defined(ANGLE_ENABLE_ASSERTS)
2778         DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2779         outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2780 #else
2781         TInfoSinkBase &outWrapper = out;
2782 #endif
2783         GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2784                               compileOptions);
2785         root.traverse(&gen);
2786     }
2787 
2788     out << "\n";
2789 
2790     return true;
2791 }
2792