1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
#include <cctype>
#include <map>
#include <unordered_map>

#include "common/system_utils.h"
#include "compiler/translator/BaseTypes.h"
#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/Name.h"
#include "compiler/translator/OutputTree.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/msl/DebugSink.h"
#include "compiler/translator/msl/EmitMetal.h"
#include "compiler/translator/msl/Layout.h"
#include "compiler/translator/msl/ProgramPrelude.h"
#include "compiler/translator/msl/RewritePipelines.h"
#include "compiler/translator/msl/TranslatorMSL.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
24
25 using namespace sh;
26
27 ////////////////////////////////////////////////////////////////////////////////
28
29 #if defined(ANGLE_ENABLE_ASSERTS)
30 using Sink = DebugSink;
31 #else
32 using Sink = TInfoSinkBase;
33 #endif
34
35 ////////////////////////////////////////////////////////////////////////////////
36
37 namespace
38 {
39
40 struct VarDecl
41 {
VarDecl__anon45ec70e50111::VarDecl42 explicit VarDecl(const TVariable &var) : mVariable(&var), mIsField(false) {}
VarDecl__anon45ec70e50111::VarDecl43 explicit VarDecl(const TField &field) : mField(&field), mIsField(true) {}
44
variable__anon45ec70e50111::VarDecl45 ANGLE_INLINE const TVariable &variable() const
46 {
47 ASSERT(isVariable());
48 return *mVariable;
49 }
50
field__anon45ec70e50111::VarDecl51 ANGLE_INLINE const TField &field() const
52 {
53 ASSERT(isField());
54 return *mField;
55 }
56
isVariable__anon45ec70e50111::VarDecl57 ANGLE_INLINE bool isVariable() const { return !mIsField; }
58
isField__anon45ec70e50111::VarDecl59 ANGLE_INLINE bool isField() const { return mIsField; }
60
type__anon45ec70e50111::VarDecl61 const TType &type() const { return isField() ? *field().type() : variable().getType(); }
62
symbolType__anon45ec70e50111::VarDecl63 SymbolType symbolType() const
64 {
65 return isField() ? field().symbolType() : variable().symbolType();
66 }
67
68 private:
69 union
70 {
71 const TVariable *mVariable;
72 const TField *mField;
73 };
74 bool mIsField;
75 };
76
77 class GenMetalTraverser : public TIntermTraverser
78 {
79 public:
80 ~GenMetalTraverser() override;
81
82 GenMetalTraverser(const TCompiler &compiler,
83 Sink &out,
84 IdGen &idGen,
85 const PipelineStructs &pipelineStructs,
86 SymbolEnv &symbolEnv,
87 const ShCompileOptions &compileOptions);
88
89 void visitSymbol(TIntermSymbol *) override;
90 void visitConstantUnion(TIntermConstantUnion *) override;
91 bool visitSwizzle(Visit, TIntermSwizzle *) override;
92 bool visitBinary(Visit, TIntermBinary *) override;
93 bool visitUnary(Visit, TIntermUnary *) override;
94 bool visitTernary(Visit, TIntermTernary *) override;
95 bool visitIfElse(Visit, TIntermIfElse *) override;
96 bool visitSwitch(Visit, TIntermSwitch *) override;
97 bool visitCase(Visit, TIntermCase *) override;
98 void visitFunctionPrototype(TIntermFunctionPrototype *) override;
99 bool visitFunctionDefinition(Visit, TIntermFunctionDefinition *) override;
100 bool visitAggregate(Visit, TIntermAggregate *) override;
101 bool visitBlock(Visit, TIntermBlock *) override;
102 bool visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *) override;
103 bool visitDeclaration(Visit, TIntermDeclaration *) override;
104 bool visitLoop(Visit, TIntermLoop *) override;
105 bool visitForLoop(TIntermLoop *);
106 bool visitWhileLoop(TIntermLoop *);
107 bool visitDoWhileLoop(TIntermLoop *);
108 bool visitBranch(Visit, TIntermBranch *) override;
109
110 private:
111 using FuncToName = std::map<ImmutableString, Name>;
112 static FuncToName BuildFuncToName();
113
114 struct EmitVariableDeclarationConfig
115 {
116 bool isParameter = false;
117 bool isMainParameter = false;
118 bool emitPostQualifier = false;
119 bool isPacked = false;
120 bool disableStructSpecifier = false;
121 bool isUBO = false;
122 const AddressSpace *isPointer = nullptr;
123 const AddressSpace *isReference = nullptr;
124 };
125
126 struct EmitTypeConfig
127 {
128 const EmitVariableDeclarationConfig *evdConfig = nullptr;
129 };
130
131 void emitIndentation();
132 void emitOpeningPointerParen();
133 void emitClosingPointerParen();
134 void emitFunctionSignature(const TFunction &func);
135 void emitFunctionReturn(const TFunction &func);
136 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
137
138 void emitNameOf(const TField &object);
139 void emitNameOf(const TSymbol &object);
140 void emitNameOf(const VarDecl &object);
141
142 void emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig);
143 void emitType(const TType &type, const EmitTypeConfig &etConfig);
144 void emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
145 const VarDecl &decl,
146 const TQualifier qualifier);
147
148 void emitLoopBody(TIntermBlock *bodyNode);
149
150 struct FieldAnnotationIndices
151 {
152 size_t attribute = 0;
153 size_t color = 0;
154 };
155
156 void emitFieldDeclaration(const TField &field,
157 const TStructure &parent,
158 FieldAnnotationIndices &annotationIndices);
159 void emitAttributeDeclaration(const TField &field, FieldAnnotationIndices &annotationIndices);
160 void emitUniformBufferDeclaration(const TField &field,
161 FieldAnnotationIndices &annotationIndices);
162 void emitStructDeclaration(const TType &type);
163 void emitOrdinaryVariableDeclaration(const VarDecl &decl,
164 const EmitVariableDeclarationConfig &evdConfig);
165 void emitVariableDeclaration(const VarDecl &decl,
166 const EmitVariableDeclarationConfig &evdConfig);
167
168 void emitOpenBrace();
169 void emitCloseBrace();
170
171 void groupedTraverse(TIntermNode &node);
172
173 const TField &getDirectField(const TFieldListCollection &fieldsNode,
174 const TConstantUnion &index);
175 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
176
177 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
178 const size_t size);
179
180 const TConstantUnion *emitConstantUnion(const TType &type, const TConstantUnion *constUnion);
181
182 void emitSingleConstant(const TConstantUnion *const constUnion);
183
184 private:
185 Sink &mOut;
186 const TCompiler &mCompiler;
187 const PipelineStructs &mPipelineStructs;
188 SymbolEnv &mSymbolEnv;
189 IdGen &mIdGen;
190 int mIndentLevel = -1;
191 int mLastIndentationPos = -1;
192 int mOpenPointerParenCount = 0;
193 bool mParentIsSwitch = false;
194 bool isTraversingVertexMain = false;
195 bool mTemporarilyDisableSemicolon = false;
196 std::unordered_map<const TSymbol *, Name> mRenamedSymbols;
197 const FuncToName mFuncToName = BuildFuncToName();
198 size_t mMainTextureIndex = 0;
199 size_t mMainSamplerIndex = 0;
200 size_t mMainUniformBufferIndex = 0;
201 size_t mDriverUniformsBindingIndex = 0;
202 size_t mUBOArgumentBufferBindingIndex = 0;
203 bool mRasterOrderGroupsSupported = false;
204 bool mInjectAsmStatementIntoLoopBodies = false;
205 };
206 } // anonymous namespace
207
// Checks on destruction that emission ended in a balanced state: no scope still open
// (indent level back at its initial -1), not inside a switch, and every "(*" written by
// emitOpeningPointerParen() closed again.
GenMetalTraverser::~GenMetalTraverser()
{
    ASSERT(mIndentLevel == -1);
    ASSERT(!mParentIsSwitch);
    ASSERT(mOpenPointerParenCount == 0);
}
214
// Creates a traverser that writes to `out`. The references passed in are stored, not copied,
// so they must outlive the traverser. Binding indices and feature switches are read once
// from `compileOptions`.
GenMetalTraverser::GenMetalTraverser(const TCompiler &compiler,
                                     Sink &out,
                                     IdGen &idGen,
                                     const PipelineStructs &pipelineStructs,
                                     SymbolEnv &symbolEnv,
                                     const ShCompileOptions &compileOptions)
    // NOTE(review): the three flags are presumably (preVisit, inVisit, postVisit) -- confirm
    // against TIntermTraverser.
    : TIntermTraverser(true, false, false),
      mOut(out),
      mCompiler(compiler),
      mPipelineStructs(pipelineStructs),
      mSymbolEnv(symbolEnv),
      mIdGen(idGen),
      mMainUniformBufferIndex(compileOptions.metal.defaultUniformsBindingIndex),
      mDriverUniformsBindingIndex(compileOptions.metal.driverUniformsBindingIndex),
      mUBOArgumentBufferBindingIndex(compileOptions.metal.UBOArgumentBufferBindingIndex),
      // True only when the fragment synchronization type is Metal raster order groups.
      mRasterOrderGroupsSupported(compileOptions.pls.fragmentSyncType ==
                                  ShFragmentSynchronizationType::RasterOrderGroups_Metal),
      mInjectAsmStatementIntoLoopBodies(compileOptions.metal.injectAsmStatementIntoLoopBodies)
{}
234
emitIndentation()235 void GenMetalTraverser::emitIndentation()
236 {
237 ASSERT(mIndentLevel >= 0);
238
239 if (mLastIndentationPos == mOut.size())
240 {
241 return; // Line is already indented.
242 }
243
244 for (int i = 0; i < mIndentLevel; ++i)
245 {
246 mOut << " ";
247 }
248
249 mLastIndentationPos = mOut.size();
250 }
251
emitOpeningPointerParen()252 void GenMetalTraverser::emitOpeningPointerParen()
253 {
254 mOut << "(*";
255 mOpenPointerParenCount++;
256 }
257
emitClosingPointerParen()258 void GenMetalTraverser::emitClosingPointerParen()
259 {
260 if (mOpenPointerParenCount > 0)
261 {
262 mOut << ")";
263 mOpenPointerParenCount--;
264 }
265 }
266
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)267 static const char *GetOperatorString(TOperator op,
268 const TType &resultType,
269 const TType *argType0,
270 const TType *argType1,
271 const TType *argType2)
272 {
273 switch (op)
274 {
275 case TOperator::EOpComma:
276 return ",";
277 case TOperator::EOpAssign:
278 return "=";
279 case TOperator::EOpInitialize:
280 return "=";
281 case TOperator::EOpAddAssign:
282 return "+=";
283 case TOperator::EOpSubAssign:
284 return "-=";
285 case TOperator::EOpMulAssign:
286 return "*=";
287 case TOperator::EOpDivAssign:
288 return "/=";
289 case TOperator::EOpIModAssign:
290 return "%=";
291 case TOperator::EOpBitShiftLeftAssign:
292 return "<<="; // TODO: Check logical vs arithmetic shifting.
293 case TOperator::EOpBitShiftRightAssign:
294 return ">>="; // TODO: Check logical vs arithmetic shifting.
295 case TOperator::EOpBitwiseAndAssign:
296 return "&=";
297 case TOperator::EOpBitwiseXorAssign:
298 return "^=";
299 case TOperator::EOpBitwiseOrAssign:
300 return "|=";
301 case TOperator::EOpAdd:
302 return "+";
303 case TOperator::EOpSub:
304 return "-";
305 case TOperator::EOpMul:
306 return "*";
307 case TOperator::EOpDiv:
308 return "/";
309 case TOperator::EOpIMod:
310 return "%";
311 case TOperator::EOpBitShiftLeft:
312 return "<<"; // TODO: Check logical vs arithmetic shifting.
313 case TOperator::EOpBitShiftRight:
314 return ">>"; // TODO: Check logical vs arithmetic shifting.
315 case TOperator::EOpBitwiseAnd:
316 return "&";
317 case TOperator::EOpBitwiseXor:
318 return "^";
319 case TOperator::EOpBitwiseOr:
320 return "|";
321 case TOperator::EOpLessThan:
322 return "<";
323 case TOperator::EOpGreaterThan:
324 return ">";
325 case TOperator::EOpLessThanEqual:
326 return "<=";
327 case TOperator::EOpGreaterThanEqual:
328 return ">=";
329 case TOperator::EOpLessThanComponentWise:
330 return "<";
331 case TOperator::EOpLessThanEqualComponentWise:
332 return "<=";
333 case TOperator::EOpGreaterThanEqualComponentWise:
334 return ">=";
335 case TOperator::EOpGreaterThanComponentWise:
336 return ">";
337 case TOperator::EOpLogicalOr:
338 return "||";
339 case TOperator::EOpLogicalXor:
340 return "!=/*xor*/"; // XXX: This might need to be handled differently for some obtuse
341 // use case.
342 case TOperator::EOpLogicalAnd:
343 return "&&";
344 case TOperator::EOpNegative:
345 return "-";
346 case TOperator::EOpPositive:
347 if (argType0->isMatrix())
348 {
349 return "";
350 }
351 return "+";
352 case TOperator::EOpLogicalNot:
353 return "!";
354 case TOperator::EOpNotComponentWise:
355 return "!";
356 case TOperator::EOpBitwiseNot:
357 return "~";
358 case TOperator::EOpPostIncrement:
359 return "++";
360 case TOperator::EOpPostDecrement:
361 return "--";
362 case TOperator::EOpPreIncrement:
363 return "++";
364 case TOperator::EOpPreDecrement:
365 return "--";
366 case TOperator::EOpVectorTimesScalarAssign:
367 return "*=";
368 case TOperator::EOpVectorTimesMatrixAssign:
369 return "*=";
370 case TOperator::EOpMatrixTimesScalarAssign:
371 return "*=";
372 case TOperator::EOpMatrixTimesMatrixAssign:
373 return "*=";
374 case TOperator::EOpVectorTimesScalar:
375 return "*";
376 case TOperator::EOpVectorTimesMatrix:
377 return "*";
378 case TOperator::EOpMatrixTimesVector:
379 return "*";
380 case TOperator::EOpMatrixTimesScalar:
381 return "*";
382 case TOperator::EOpMatrixTimesMatrix:
383 return "*";
384 case TOperator::EOpEqualComponentWise:
385 return "==";
386 case TOperator::EOpNotEqualComponentWise:
387 return "!=";
388
389 case TOperator::EOpEqual:
390 if ((argType0->getStruct() && argType1->getStruct()) &&
391 (argType0->isArray() && argType1->isArray()))
392 {
393 return "ANGLE_equalStructArray";
394 }
395
396 if ((argType0->isVector() && argType1->isVector()) ||
397 (argType0->getStruct() && argType1->getStruct()) ||
398 (argType0->isArray() && argType1->isArray()) ||
399 (argType0->isMatrix() && argType1->isMatrix()))
400
401 {
402 return "ANGLE_equal";
403 }
404
405 return "==";
406
407 case TOperator::EOpNotEqual:
408 if ((argType0->getStruct() && argType1->getStruct()) &&
409 (argType0->isArray() && argType1->isArray()))
410 {
411 return "ANGLE_notEqualStructArray";
412 }
413
414 if ((argType0->isVector() && argType1->isVector()) ||
415 (argType0->isArray() && argType1->isArray()) ||
416 (argType0->isMatrix() && argType1->isMatrix()))
417 {
418 return "ANGLE_notEqual";
419 }
420 else if (argType0->getStruct() && argType1->getStruct())
421 {
422 return "ANGLE_notEqualStruct";
423 }
424 return "!=";
425
426 case TOperator::EOpKill:
427 UNIMPLEMENTED();
428 return "kill";
429 case TOperator::EOpReturn:
430 return "return";
431 case TOperator::EOpBreak:
432 return "break";
433 case TOperator::EOpContinue:
434 return "continue";
435
436 case TOperator::EOpRadians:
437 return "ANGLE_radians";
438 case TOperator::EOpDegrees:
439 return "ANGLE_degrees";
440 case TOperator::EOpAtan:
441 return argType1 == nullptr ? "metal::atan" : "metal::atan2";
442 case TOperator::EOpMod:
443 return "ANGLE_mod"; // differs from metal::mod
444 case TOperator::EOpRefract:
445 return argType0->isVector() ? "metal::refract" : "ANGLE_refract_scalar";
446 case TOperator::EOpDistance:
447 return argType0->isVector() ? "metal::distance" : "ANGLE_distance_scalar";
448 case TOperator::EOpLength:
449 return argType0->isVector() ? "metal::length" : "metal::abs";
450 case TOperator::EOpDot:
451 return argType0->isVector() ? "metal::dot" : "*";
452 case TOperator::EOpNormalize:
453 return argType0->isVector() ? "metal::fast::normalize" : "metal::sign";
454 case TOperator::EOpFaceforward:
455 return argType0->isVector() ? "metal::faceforward" : "ANGLE_faceforward_scalar";
456 case TOperator::EOpReflect:
457 return argType0->isVector() ? "metal::reflect" : "ANGLE_reflect_scalar";
458 case TOperator::EOpMatrixCompMult:
459 return "ANGLE_componentWiseMultiply";
460 case TOperator::EOpOuterProduct:
461 return "ANGLE_outerProduct";
462 case TOperator::EOpSign:
463 return argType0->getBasicType() == EbtFloat ? "metal::sign" : "ANGLE_sign_int";
464
465 case TOperator::EOpAbs:
466 return "metal::abs";
467 case TOperator::EOpAll:
468 return "metal::all";
469 case TOperator::EOpAny:
470 return "metal::any";
471 case TOperator::EOpSin:
472 return "metal::sin";
473 case TOperator::EOpCos:
474 return "metal::cos";
475 case TOperator::EOpTan:
476 return "metal::tan";
477 case TOperator::EOpAsin:
478 return "metal::asin";
479 case TOperator::EOpAcos:
480 return "metal::acos";
481 case TOperator::EOpSinh:
482 return "metal::sinh";
483 case TOperator::EOpCosh:
484 return "metal::cosh";
485 case TOperator::EOpTanh:
486 return resultType.getPrecision() == TPrecision::EbpHigh ? "metal::precise::tanh"
487 : "metal::tanh";
488 case TOperator::EOpAsinh:
489 return "metal::asinh";
490 case TOperator::EOpAcosh:
491 return "metal::acosh";
492 case TOperator::EOpAtanh:
493 return "metal::atanh";
494 case TOperator::EOpFma:
495 return "metal::fma";
496 case TOperator::EOpPow:
497 return "metal::powr"; // GLSL's pow excludes negative x
498 case TOperator::EOpExp:
499 return "metal::exp";
500 case TOperator::EOpExp2:
501 return "metal::exp2";
502 case TOperator::EOpLog:
503 return "metal::log";
504 case TOperator::EOpLog2:
505 return "metal::log2";
506 case TOperator::EOpSqrt:
507 return "metal::sqrt";
508 case TOperator::EOpFloor:
509 return "metal::floor";
510 case TOperator::EOpTrunc:
511 return "metal::trunc";
512 case TOperator::EOpCeil:
513 return "metal::ceil";
514 case TOperator::EOpFract:
515 return "metal::fract";
516 case TOperator::EOpMin:
517 return "metal::min";
518 case TOperator::EOpMax:
519 return "metal::max";
520 case TOperator::EOpRound:
521 return "metal::round";
522 case TOperator::EOpRoundEven:
523 return "metal::rint";
524 case TOperator::EOpClamp:
525 return "metal::clamp"; // TODO fast vs precise namespace
526 case TOperator::EOpSaturate:
527 return "metal::saturate"; // TODO fast vs precise namespace
528 case TOperator::EOpMix:
529 if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
530 {
531 return "ANGLE_mix_bool";
532 }
533 return "metal::mix";
534 case TOperator::EOpStep:
535 return "metal::step";
536 case TOperator::EOpSmoothstep:
537 return "metal::smoothstep";
538 case TOperator::EOpModf:
539 return "metal::modf";
540 case TOperator::EOpIsnan:
541 return "metal::isnan";
542 case TOperator::EOpIsinf:
543 return "metal::isinf";
544 case TOperator::EOpLdexp:
545 return "metal::ldexp";
546 case TOperator::EOpFrexp:
547 return "metal::frexp";
548 case TOperator::EOpInversesqrt:
549 return "metal::rsqrt";
550 case TOperator::EOpCross:
551 return "metal::cross";
552 case TOperator::EOpDFdx:
553 return "metal::dfdx";
554 case TOperator::EOpDFdy:
555 return "metal::dfdy";
556 case TOperator::EOpFwidth:
557 return "metal::fwidth";
558 case TOperator::EOpTranspose:
559 return "metal::transpose";
560 case TOperator::EOpDeterminant:
561 return "metal::determinant";
562
563 case TOperator::EOpInverse:
564 return "ANGLE_inverse";
565
566 case TOperator::EOpInterpolateAtCentroid:
567 return "ANGLE_interpolateAtCentroid";
568 case TOperator::EOpInterpolateAtSample:
569 return "ANGLE_interpolateAtSample";
570 case TOperator::EOpInterpolateAtOffset:
571 return "ANGLE_interpolateAtOffset";
572 case TOperator::EOpInterpolateAtCenter:
573 return "ANGLE_interpolateAtCenter";
574
575 case TOperator::EOpFloatBitsToInt:
576 case TOperator::EOpFloatBitsToUint:
577 case TOperator::EOpIntBitsToFloat:
578 case TOperator::EOpUintBitsToFloat:
579 {
580 #define RETURN_AS_TYPE_SCALAR() \
581 do \
582 switch (resultType.getBasicType()) \
583 { \
584 case TBasicType::EbtInt: \
585 return "as_type<int>"; \
586 case TBasicType::EbtUInt: \
587 return "as_type<uint32_t>"; \
588 case TBasicType::EbtFloat: \
589 return "as_type<float>"; \
590 default: \
591 UNIMPLEMENTED(); \
592 return "TOperator_TODO"; \
593 } \
594 while (false)
595
596 #define RETURN_AS_TYPE(post) \
597 do \
598 switch (resultType.getBasicType()) \
599 { \
600 case TBasicType::EbtInt: \
601 return "as_type<int" post ">"; \
602 case TBasicType::EbtUInt: \
603 return "as_type<uint" post ">"; \
604 case TBasicType::EbtFloat: \
605 return "as_type<float" post ">"; \
606 default: \
607 UNIMPLEMENTED(); \
608 return "TOperator_TODO"; \
609 } \
610 while (false)
611
612 if (resultType.isScalar())
613 {
614 RETURN_AS_TYPE_SCALAR();
615 }
616 else if (resultType.isVector())
617 {
618 switch (resultType.getNominalSize())
619 {
620 case 2:
621 RETURN_AS_TYPE("2");
622 case 3:
623 RETURN_AS_TYPE("3");
624 case 4:
625 RETURN_AS_TYPE("4");
626 default:
627 UNREACHABLE();
628 return nullptr;
629 }
630 }
631 else
632 {
633 UNIMPLEMENTED();
634 return "TOperator_TODO";
635 }
636
637 #undef RETURN_AS_TYPE
638 #undef RETURN_AS_TYPE_SCALAR
639 }
640
641 case TOperator::EOpPackUnorm2x16:
642 return "metal::pack_float_to_unorm2x16";
643 case TOperator::EOpPackSnorm2x16:
644 return "metal::pack_float_to_snorm2x16";
645
646 case TOperator::EOpPackUnorm4x8:
647 return "metal::pack_float_to_unorm4x8";
648 case TOperator::EOpPackSnorm4x8:
649 return "metal::pack_float_to_snorm4x8";
650
651 case TOperator::EOpUnpackUnorm2x16:
652 return "metal::unpack_unorm2x16_to_float";
653 case TOperator::EOpUnpackSnorm2x16:
654 return "metal::unpack_snorm2x16_to_float";
655
656 case TOperator::EOpUnpackUnorm4x8:
657 return "metal::unpack_unorm4x8_to_float";
658 case TOperator::EOpUnpackSnorm4x8:
659 return "metal::unpack_snorm4x8_to_float";
660
661 case TOperator::EOpPackHalf2x16:
662 return "ANGLE_pack_half_2x16";
663 case TOperator::EOpUnpackHalf2x16:
664 return "ANGLE_unpack_half_2x16";
665
666 case TOperator::EOpNumSamples:
667 return "metal::get_num_samples";
668 case TOperator::EOpSamplePosition:
669 return "metal::get_sample_position";
670
671 case TOperator::EOpBitfieldExtract:
672 case TOperator::EOpBitfieldInsert:
673 case TOperator::EOpBitfieldReverse:
674 case TOperator::EOpBitCount:
675 case TOperator::EOpFindLSB:
676 case TOperator::EOpFindMSB:
677 case TOperator::EOpUaddCarry:
678 case TOperator::EOpUsubBorrow:
679 case TOperator::EOpUmulExtended:
680 case TOperator::EOpImulExtended:
681 case TOperator::EOpBarrier:
682 case TOperator::EOpMemoryBarrier:
683 case TOperator::EOpMemoryBarrierAtomicCounter:
684 case TOperator::EOpMemoryBarrierBuffer:
685 case TOperator::EOpMemoryBarrierShared:
686 case TOperator::EOpGroupMemoryBarrier:
687 case TOperator::EOpAtomicAdd:
688 case TOperator::EOpAtomicMin:
689 case TOperator::EOpAtomicMax:
690 case TOperator::EOpAtomicAnd:
691 case TOperator::EOpAtomicOr:
692 case TOperator::EOpAtomicXor:
693 case TOperator::EOpAtomicExchange:
694 case TOperator::EOpAtomicCompSwap:
695 case TOperator::EOpEmitVertex:
696 case TOperator::EOpEndPrimitive:
697 case TOperator::EOpArrayLength:
698 UNIMPLEMENTED();
699 return "TOperator_TODO";
700
701 case TOperator::EOpNull:
702 case TOperator::EOpConstruct:
703 case TOperator::EOpCallFunctionInAST:
704 case TOperator::EOpCallInternalRawFunction:
705 case TOperator::EOpIndexDirect:
706 case TOperator::EOpIndexIndirect:
707 case TOperator::EOpIndexDirectStruct:
708 case TOperator::EOpIndexDirectInterfaceBlock:
709 UNREACHABLE();
710 return nullptr;
711 default:
712 // Any other built-in function.
713 return nullptr;
714 }
715 }
716
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)717 static bool IsSymbolicOperator(TOperator op,
718 const TType &resultType,
719 const TType *argType0,
720 const TType *argType1)
721 {
722 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
723 if (operatorString == nullptr)
724 {
725 return false;
726 }
727 return !std::isalnum(operatorString[0]);
728 }
729
AsSpecificBinaryNode(TIntermNode & node,TOperator op)730 static TIntermBinary *AsSpecificBinaryNode(TIntermNode &node, TOperator op)
731 {
732 TIntermBinary *binaryNode = node.getAsBinaryNode();
733 if (binaryNode)
734 {
735 return binaryNode->getOp() == op ? binaryNode : nullptr;
736 }
737 return nullptr;
738 }
739
Parenthesize(TIntermNode & node)740 static bool Parenthesize(TIntermNode &node)
741 {
742 if (node.getAsSymbolNode())
743 {
744 return false;
745 }
746 if (node.getAsConstantUnion())
747 {
748 return false;
749 }
750 if (node.getAsAggregate())
751 {
752 return false;
753 }
754 if (node.getAsSwizzleNode())
755 {
756 return false;
757 }
758
759 if (TIntermUnary *unaryNode = node.getAsUnaryNode())
760 {
761 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
762 const TType &resultType = unaryNode->getType();
763 const TType &argType = unaryNode->getOperand()->getType();
764 return IsSymbolicOperator(unaryNode->getOp(), resultType, &argType, nullptr);
765 }
766
767 if (TIntermBinary *binaryNode = node.getAsBinaryNode())
768 {
769 // TODO: Use a precedence and associativity rules instead of this ad-hoc impl.
770 const TOperator op = binaryNode->getOp();
771 switch (op)
772 {
773 case TOperator::EOpIndexDirectStruct:
774 case TOperator::EOpIndexDirectInterfaceBlock:
775 case TOperator::EOpIndexDirect:
776 case TOperator::EOpIndexIndirect:
777 return Parenthesize(*binaryNode->getLeft());
778
779 case TOperator::EOpAssign:
780 case TOperator::EOpInitialize:
781 return AsSpecificBinaryNode(*binaryNode->getRight(), TOperator::EOpComma);
782
783 default:
784 {
785 const TType &resultType = binaryNode->getType();
786 const TType &leftType = binaryNode->getLeft()->getType();
787 const TType &rightType = binaryNode->getRight()->getType();
788 return IsSymbolicOperator(binaryNode->getOp(), resultType, &leftType, &rightType);
789 }
790 }
791 }
792
793 return true;
794 }
795
groupedTraverse(TIntermNode & node)796 void GenMetalTraverser::groupedTraverse(TIntermNode &node)
797 {
798 const bool emitParens = Parenthesize(node);
799
800 if (emitParens)
801 {
802 mOut << "(";
803 }
804
805 node.traverse(this);
806
807 if (emitParens)
808 {
809 mOut << ")";
810 }
811 }
812
emitPostQualifier(const EmitVariableDeclarationConfig & evdConfig,const VarDecl & decl,const TQualifier qualifier)813 void GenMetalTraverser::emitPostQualifier(const EmitVariableDeclarationConfig &evdConfig,
814 const VarDecl &decl,
815 const TQualifier qualifier)
816 {
817 bool isInvariant = false;
818 switch (qualifier)
819 {
820 case TQualifier::EvqPosition:
821 isInvariant = decl.type().isInvariant();
822 [[fallthrough]];
823 case TQualifier::EvqFragCoord:
824 mOut << " [[position]]";
825 break;
826
827 case TQualifier::EvqClipDistance:
828 mOut << " [[clip_distance]] [" << decl.type().getOutermostArraySize() << "]";
829 break;
830
831 case TQualifier::EvqPointSize:
832 mOut << " [[point_size]]";
833 break;
834
835 case TQualifier::EvqVertexID:
836 if (evdConfig.isMainParameter)
837 {
838 mOut << " [[vertex_id]]";
839 }
840 break;
841
842 case TQualifier::EvqPointCoord:
843 if (evdConfig.isMainParameter)
844 {
845 mOut << " [[point_coord]]";
846 }
847 break;
848
849 case TQualifier::EvqFrontFacing:
850 if (evdConfig.isMainParameter)
851 {
852 mOut << " [[front_facing]]";
853 }
854 break;
855
856 case TQualifier::EvqSampleID:
857 if (evdConfig.isMainParameter)
858 {
859 mOut << " [[sample_id]]";
860 }
861 break;
862
863 case TQualifier::EvqSampleMaskIn:
864 if (evdConfig.isMainParameter)
865 {
866 mOut << " [[sample_mask]]";
867 }
868 break;
869
870 default:
871 break;
872 }
873
874 if (isInvariant)
875 {
876 mOut << " [[invariant]]";
877
878 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
879 reflection->hasInvariance = true;
880 }
881 }
882
emitLoopBody(TIntermBlock * bodyNode)883 void GenMetalTraverser::emitLoopBody(TIntermBlock *bodyNode)
884 {
885 if (mInjectAsmStatementIntoLoopBodies)
886 {
887 emitOpenBrace();
888
889 emitIndentation();
890 mOut << "__asm__(\"\");\n";
891 }
892
893 bodyNode->traverse(this);
894
895 if (mInjectAsmStatementIntoLoopBodies)
896 {
897 emitCloseBrace();
898 }
899 }
900
EmitName(Sink & out,const Name & name)901 static void EmitName(Sink &out, const Name &name)
902 {
903 #if defined(ANGLE_ENABLE_ASSERTS)
904 DebugSink::EscapedSink escapedOut(out.escape());
905 #else
906 TInfoSinkBase &escapedOut = out;
907 #endif
908 name.emit(escapedOut);
909 }
910
emitNameOf(const TField & object)911 void GenMetalTraverser::emitNameOf(const TField &object)
912 {
913 EmitName(mOut, Name(object));
914 }
915
emitNameOf(const TSymbol & object)916 void GenMetalTraverser::emitNameOf(const TSymbol &object)
917 {
918 auto it = mRenamedSymbols.find(&object);
919 if (it == mRenamedSymbols.end())
920 {
921 EmitName(mOut, Name(object));
922 }
923 else
924 {
925 EmitName(mOut, it->second);
926 }
927 }
928
emitNameOf(const VarDecl & object)929 void GenMetalTraverser::emitNameOf(const VarDecl &object)
930 {
931 if (object.isField())
932 {
933 emitNameOf(object.field());
934 }
935 else
936 {
937 emitNameOf(object.variable());
938 }
939 }
940
emitBareTypeName(const TType & type,const EmitTypeConfig & etConfig)941 void GenMetalTraverser::emitBareTypeName(const TType &type, const EmitTypeConfig &etConfig)
942 {
943 const TBasicType basicType = type.getBasicType();
944
945 switch (basicType)
946 {
947 case TBasicType::EbtVoid:
948 case TBasicType::EbtBool:
949 case TBasicType::EbtFloat:
950 case TBasicType::EbtInt:
951 {
952 mOut << type.getBasicString();
953 break;
954 }
955 case TBasicType::EbtUInt:
956 {
957 if (type.isScalar())
958 {
959 mOut << "uint32_t";
960 }
961 else
962 {
963 mOut << type.getBasicString();
964 }
965 }
966 break;
967
968 case TBasicType::EbtStruct:
969 {
970 const TStructure &structure = *type.getStruct();
971 emitNameOf(structure);
972 }
973 break;
974
975 case TBasicType::EbtInterfaceBlock:
976 {
977 const TInterfaceBlock &interfaceBlock = *type.getInterfaceBlock();
978 emitNameOf(interfaceBlock);
979 }
980 break;
981
982 default:
983 {
984 if (IsSampler(basicType))
985 {
986 if (etConfig.evdConfig && etConfig.evdConfig->isMainParameter)
987 {
988 EmitName(mOut, GetTextureTypeName(basicType));
989 }
990 else
991 {
992 const TStructure &env = mSymbolEnv.getTextureEnv(basicType);
993 emitNameOf(env);
994 }
995 }
996 else if (IsImage(basicType))
997 {
998 mOut << "metal::texture2d<";
999 switch (type.getBasicType())
1000 {
1001 case EbtImage2D:
1002 mOut << "float";
1003 break;
1004 case EbtIImage2D:
1005 mOut << "int";
1006 break;
1007 case EbtUImage2D:
1008 mOut << "uint";
1009 break;
1010 default:
1011 UNIMPLEMENTED();
1012 break;
1013 }
1014 if (type.getMemoryQualifier().readonly || type.getMemoryQualifier().writeonly)
1015 {
1016 UNIMPLEMENTED();
1017 }
1018 mOut << ", metal::access::read_write>";
1019 }
1020 else
1021 {
1022 UNIMPLEMENTED();
1023 }
1024 }
1025 }
1026 }
1027
emitType(const TType & type,const EmitTypeConfig & etConfig)1028 void GenMetalTraverser::emitType(const TType &type, const EmitTypeConfig &etConfig)
1029 {
1030 const bool isUBO = etConfig.evdConfig ? etConfig.evdConfig->isUBO : false;
1031 if (etConfig.evdConfig)
1032 {
1033 const auto &evdConfig = *etConfig.evdConfig;
1034 if (isUBO)
1035 {
1036 if (type.isArray())
1037 {
1038 mOut << "metal::array<";
1039 }
1040 }
1041 if (evdConfig.isPointer)
1042 {
1043 mOut << toString(*evdConfig.isPointer);
1044 mOut << " ";
1045 }
1046 else if (evdConfig.isReference)
1047 {
1048 mOut << toString(*evdConfig.isReference);
1049 mOut << " ";
1050 }
1051 }
1052
1053 if (!isUBO)
1054 {
1055 if (type.isArray())
1056 {
1057 mOut << "metal::array<";
1058 }
1059 }
1060
1061 if (type.isInterpolant())
1062 {
1063 mOut << "metal::interpolant<";
1064 }
1065
1066 if (type.isVector() || type.isMatrix())
1067 {
1068 mOut << "metal::";
1069 }
1070
1071 if (etConfig.evdConfig && etConfig.evdConfig->isPacked)
1072 {
1073 mOut << "packed_";
1074 }
1075
1076 emitBareTypeName(type, etConfig);
1077
1078 if (type.isVector())
1079 {
1080 mOut << static_cast<uint32_t>(type.getNominalSize());
1081 }
1082 else if (type.isMatrix())
1083 {
1084 mOut << static_cast<uint32_t>(type.getCols()) << "x"
1085 << static_cast<uint32_t>(type.getRows());
1086 }
1087
1088 if (type.isInterpolant())
1089 {
1090 mOut << ", metal::interpolation::";
1091 switch (type.getQualifier())
1092 {
1093 case EvqNoPerspectiveIn:
1094 case EvqNoPerspectiveCentroidIn:
1095 case EvqNoPerspectiveSampleIn:
1096 mOut << "no_";
1097 break;
1098 default:
1099 break;
1100 }
1101 mOut << "perspective>";
1102 }
1103
1104 if (!isUBO)
1105 {
1106 if (type.isArray())
1107 {
1108 for (auto size : type.getArraySizes())
1109 {
1110 mOut << ", " << size;
1111 }
1112 mOut << ">";
1113 }
1114 }
1115
1116 if (etConfig.evdConfig)
1117 {
1118 const auto &evdConfig = *etConfig.evdConfig;
1119 if (evdConfig.isPointer)
1120 {
1121 mOut << " *";
1122 }
1123 else if (evdConfig.isReference)
1124 {
1125 mOut << " &";
1126 }
1127 if (isUBO)
1128 {
1129 if (type.isArray())
1130 {
1131 for (auto size : type.getArraySizes())
1132 {
1133 mOut << ", " << size;
1134 }
1135 mOut << ">";
1136 }
1137 }
1138 }
1139 }
1140
emitFieldDeclaration(const TField & field,const TStructure & parent,FieldAnnotationIndices & annotationIndices)1141 void GenMetalTraverser::emitFieldDeclaration(const TField &field,
1142 const TStructure &parent,
1143 FieldAnnotationIndices &annotationIndices)
1144 {
1145 const TType &type = *field.type();
1146 const TBasicType basic = type.getBasicType();
1147
1148 EmitVariableDeclarationConfig evdConfig;
1149 evdConfig.emitPostQualifier = true;
1150 evdConfig.disableStructSpecifier = true;
1151 evdConfig.isPacked = mSymbolEnv.isPacked(field);
1152 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1153 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1154 evdConfig.isReference = mSymbolEnv.isReference(field);
1155 emitVariableDeclaration(VarDecl(field), evdConfig);
1156
1157 const TQualifier qual = type.getQualifier();
1158 switch (qual)
1159 {
1160 case TQualifier::EvqFlatIn:
1161 if (mPipelineStructs.fragmentIn.external == &parent)
1162 {
1163 mOut << " [[flat]]";
1164 TranslatorMetalReflection *reflection =
1165 mtl::getTranslatorMetalReflection(&mCompiler);
1166 reflection->hasFlatInput = true;
1167 }
1168 break;
1169
1170 case TQualifier::EvqNoPerspectiveIn:
1171 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1172 {
1173 mOut << " [[center_no_perspective]]";
1174 }
1175 break;
1176
1177 case TQualifier::EvqCentroidIn:
1178 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1179 {
1180 mOut << " [[centroid_perspective]]";
1181 }
1182 break;
1183
1184 case TQualifier::EvqSampleIn:
1185 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1186 {
1187 mOut << " [[sample_perspective]]";
1188 }
1189 break;
1190
1191 case TQualifier::EvqNoPerspectiveCentroidIn:
1192 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1193 {
1194 mOut << " [[centroid_no_perspective]]";
1195 }
1196 break;
1197
1198 case TQualifier::EvqNoPerspectiveSampleIn:
1199 if (mPipelineStructs.fragmentIn.external == &parent && !type.isInterpolant())
1200 {
1201 mOut << " [[sample_no_perspective]]";
1202 }
1203 break;
1204
1205 case TQualifier::EvqFragColor:
1206 mOut << " [[color(0)]]";
1207 break;
1208
1209 case TQualifier::EvqSecondaryFragColorEXT:
1210 case TQualifier::EvqSecondaryFragDataEXT:
1211 mOut << " [[color(0), index(1)]]";
1212 break;
1213
1214 case TQualifier::EvqFragmentOut:
1215 case TQualifier::EvqFragmentInOut:
1216 case TQualifier::EvqFragData:
1217 if (mPipelineStructs.fragmentOut.external == &parent ||
1218 mPipelineStructs.fragmentOut.externalExtra == &parent)
1219 {
1220 if ((type.isVector() &&
1221 (basic == TBasicType::EbtInt || basic == TBasicType::EbtUInt ||
1222 basic == TBasicType::EbtFloat)) ||
1223 qual == EvqFragData)
1224 {
1225 // The OpenGL ES 3.0 spec says locations must be specified
1226 // unless there is only a single output, in which case the
1227 // location is 0. So, when we get to this point the shader
1228 // will have been rejected if locations are not specified
1229 // and there is more than one output.
1230 const TLayoutQualifier &layoutQualifier = type.getLayoutQualifier();
1231 if (layoutQualifier.locationsSpecified)
1232 {
1233 mOut << " [[color(" << layoutQualifier.location << ")";
1234 ASSERT(layoutQualifier.index >= -1 && layoutQualifier.index <= 1);
1235 if (layoutQualifier.index == 1)
1236 {
1237 mOut << ", index(1)";
1238 }
1239 }
1240 else if (qual == EvqFragData)
1241 {
1242 mOut << " [[color(" << annotationIndices.color++ << ")";
1243 }
1244 else
1245 {
1246 // Either the only output or EXT_blend_func_extended is used;
1247 // actual assignment will happen in UpdateFragmentShaderOutputs.
1248 mOut << " [[" << sh::kUnassignedFragmentOutputString;
1249 }
1250 if (mRasterOrderGroupsSupported && qual == TQualifier::EvqFragmentInOut)
1251 {
1252 // Put fragment inouts in their own raster order group for better
1253 // parallelism.
1254 // NOTE: this is not required for the reads to be ordered and coherent.
1255 // TODO(anglebug.com/40096838): Consider making raster order groups a PLS
1256 // layout qualifier?
1257 mOut << ", raster_order_group(0)";
1258 }
1259 mOut << "]]";
1260 }
1261 }
1262 break;
1263
1264 case TQualifier::EvqFragDepth:
1265 mOut << " [[depth(";
1266 switch (type.getLayoutQualifier().depth)
1267 {
1268 case EdGreater:
1269 mOut << "greater";
1270 break;
1271 case EdLess:
1272 mOut << "less";
1273 break;
1274 default:
1275 mOut << "any";
1276 break;
1277 }
1278 mOut << "), function_constant(" << sh::mtl::kDepthWriteEnabledConstName << ")]]";
1279 break;
1280
1281 case TQualifier::EvqSampleMask:
1282 if (field.symbolType() == SymbolType::AngleInternal)
1283 {
1284 mOut << " [[sample_mask, function_constant("
1285 << sh::mtl::kSampleMaskWriteEnabledConstName << ")]]";
1286 }
1287 break;
1288
1289 default:
1290 break;
1291 }
1292
1293 if (IsImage(type.getBasicType()))
1294 {
1295 if (type.getArraySizeProduct() != 1)
1296 {
1297 UNIMPLEMENTED();
1298 }
1299 mOut << " [[";
1300 if (type.getLayoutQualifier().rasterOrdered)
1301 {
1302 mOut << "raster_order_group(0), ";
1303 }
1304 mOut << "texture(" << mMainTextureIndex << ")]]";
1305 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1306 reflection->addRWTextureBinding(type.getLayoutQualifier().binding,
1307 static_cast<int>(mMainTextureIndex));
1308 ++mMainTextureIndex;
1309 }
1310 }
1311
BuildExternalAttributeIndexMap(const TCompiler & compiler,const PipelineScoped<TStructure> & structure)1312 static std::map<Name, size_t> BuildExternalAttributeIndexMap(
1313 const TCompiler &compiler,
1314 const PipelineScoped<TStructure> &structure)
1315 {
1316 ASSERT(structure.isTotallyFull());
1317
1318 const auto &shaderVars = compiler.getAttributes();
1319 const size_t shaderVarSize = shaderVars.size();
1320 size_t shaderVarIndex = 0;
1321
1322 const auto &externalFields = structure.external->fields();
1323 const size_t externalSize = externalFields.size();
1324 size_t externalIndex = 0;
1325
1326 const auto &internalFields = structure.internal->fields();
1327 const size_t internalSize = internalFields.size();
1328 size_t internalIndex = 0;
1329
1330 // Internal fields are never split. External fields are sometimes split.
1331 ASSERT(externalSize >= internalSize);
1332
1333 // Structures do not contain any inactive fields.
1334 ASSERT(shaderVarSize >= internalSize);
1335
1336 std::map<Name, size_t> externalNameToAttributeIndex;
1337 size_t attributeIndex = 0;
1338
1339 while (internalIndex < internalSize)
1340 {
1341 const TField &internalField = *internalFields[internalIndex];
1342 const Name internalName = Name(internalField);
1343 const TType &internalType = *internalField.type();
1344 while (internalName.rawName() != shaderVars[shaderVarIndex].name &&
1345 internalName.rawName() != shaderVars[shaderVarIndex].mappedName)
1346 {
1347 // This case represents an inactive field.
1348
1349 ++shaderVarIndex;
1350 ASSERT(shaderVarIndex < shaderVarSize);
1351
1352 ++attributeIndex; // TODO: Might need to increment more if shader var type is a matrix.
1353 }
1354
1355 const size_t cols =
1356 (internalType.isMatrix() && !externalFields[externalIndex]->type()->isMatrix())
1357 ? internalType.getCols()
1358 : 1;
1359
1360 for (size_t c = 0; c < cols; ++c)
1361 {
1362 const TField &externalField = *externalFields[externalIndex];
1363 const Name externalName = Name(externalField);
1364
1365 externalNameToAttributeIndex[externalName] = attributeIndex;
1366
1367 ++externalIndex;
1368 ++attributeIndex;
1369 }
1370
1371 ++shaderVarIndex;
1372 ++internalIndex;
1373 }
1374
1375 ASSERT(shaderVarIndex <= shaderVarSize);
1376 ASSERT(externalIndex <= externalSize); // less than if padding was introduced
1377 ASSERT(internalIndex == internalSize);
1378
1379 return externalNameToAttributeIndex;
1380 }
1381
emitAttributeDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1382 void GenMetalTraverser::emitAttributeDeclaration(const TField &field,
1383 FieldAnnotationIndices &annotationIndices)
1384 {
1385 EmitVariableDeclarationConfig evdConfig;
1386 evdConfig.disableStructSpecifier = true;
1387 emitVariableDeclaration(VarDecl(field), evdConfig);
1388 mOut << sh::kUnassignedAttributeString;
1389 }
1390
emitUniformBufferDeclaration(const TField & field,FieldAnnotationIndices & annotationIndices)1391 void GenMetalTraverser::emitUniformBufferDeclaration(const TField &field,
1392 FieldAnnotationIndices &annotationIndices)
1393 {
1394 EmitVariableDeclarationConfig evdConfig;
1395 evdConfig.disableStructSpecifier = true;
1396 evdConfig.isUBO = mSymbolEnv.isUBO(field);
1397 evdConfig.isPointer = mSymbolEnv.isPointer(field);
1398 emitVariableDeclaration(VarDecl(field), evdConfig);
1399 mOut << "[[id(" << annotationIndices.attribute << ")]]";
1400
1401 const TType &type = *field.type();
1402 const int arraySize = type.isArray() ? type.getArraySizeProduct() : 1;
1403
1404 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
1405 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1406 const TStructure *structure = type.getStruct();
1407 const std::string originalName = reflection->getOriginalName(structure->uniqueId().get());
1408 reflection->addUniformBufferBinding(
1409 originalName,
1410 {.bindIndex = annotationIndices.attribute, .arraySize = static_cast<size_t>(arraySize)});
1411
1412 annotationIndices.attribute += arraySize;
1413 }
1414
emitStructDeclaration(const TType & type)1415 void GenMetalTraverser::emitStructDeclaration(const TType &type)
1416 {
1417 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1418 ASSERT(type.isStructSpecifier());
1419
1420 mOut << "struct ";
1421 emitBareTypeName(type, {});
1422
1423 mOut << "\n";
1424 emitOpenBrace();
1425
1426 const TStructure &structure = *type.getStruct();
1427 std::map<Name, size_t> fieldToAttributeIndex;
1428 const bool hasAttributeIndices = mPipelineStructs.vertexIn.external == &structure;
1429 const bool hasUniformBufferIndicies = mPipelineStructs.uniformBuffers.external == &structure;
1430 const bool reclaimUnusedAttributeIndices = mCompiler.getShaderVersion() < 300;
1431
1432 if (hasAttributeIndices)
1433 {
1434 // When attribute aliasing is supported, external attribute struct is filled post-link.
1435 if (mCompiler.supportsAttributeAliasing())
1436 {
1437 mtl::getTranslatorMetalReflection(&mCompiler)->hasAttributeAliasing = true;
1438 mOut << "@@Attrib-Bindings@@\n";
1439 emitCloseBrace();
1440 return;
1441 }
1442
1443 fieldToAttributeIndex =
1444 BuildExternalAttributeIndexMap(mCompiler, mPipelineStructs.vertexIn);
1445 }
1446
1447 FieldAnnotationIndices annotationIndices;
1448
1449 for (const TField *field : structure.fields())
1450 {
1451 emitIndentation();
1452 if (hasAttributeIndices)
1453 {
1454 const auto it = fieldToAttributeIndex.find(Name(*field));
1455 if (it == fieldToAttributeIndex.end())
1456 {
1457 ASSERT(field->symbolType() == SymbolType::AngleInternal);
1458 ASSERT(field->name().beginsWith("_"));
1459 ASSERT(angle::EndsWith(field->name().data(), "_pad"));
1460 emitFieldDeclaration(*field, structure, annotationIndices);
1461 }
1462 else
1463 {
1464 ASSERT(field->symbolType() != SymbolType::AngleInternal ||
1465 !field->name().beginsWith("_") ||
1466 !angle::EndsWith(field->name().data(), "_pad"));
1467 if (!reclaimUnusedAttributeIndices)
1468 {
1469 annotationIndices.attribute = it->second;
1470 }
1471 emitAttributeDeclaration(*field, annotationIndices);
1472 }
1473 }
1474 else if (hasUniformBufferIndicies)
1475 {
1476 emitUniformBufferDeclaration(*field, annotationIndices);
1477 }
1478 else
1479 {
1480 emitFieldDeclaration(*field, structure, annotationIndices);
1481 }
1482 mOut << ";\n";
1483 }
1484
1485 emitCloseBrace();
1486 }
1487
emitOrdinaryVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1488 void GenMetalTraverser::emitOrdinaryVariableDeclaration(
1489 const VarDecl &decl,
1490 const EmitVariableDeclarationConfig &evdConfig)
1491 {
1492 EmitTypeConfig etConfig;
1493 etConfig.evdConfig = &evdConfig;
1494
1495 const TType &type = decl.type();
1496 if (type.getQualifier() == TQualifier::EvqClipDistance)
1497 {
1498 // Clip distance output uses float[n] type instead of metal::array.
1499 // The element count is emitted after the post qualifier.
1500 ASSERT(type.getBasicType() == TBasicType::EbtFloat);
1501 mOut << "float";
1502 }
1503 else if (type.getQualifier() == TQualifier::EvqSampleID && evdConfig.isMainParameter)
1504 {
1505 // Metal's [[sample_id]] must be unsigned
1506 ASSERT(type.getBasicType() == TBasicType::EbtInt);
1507 mOut << "uint32_t";
1508 }
1509 else
1510 {
1511 emitType(type, etConfig);
1512 }
1513 if (decl.symbolType() != SymbolType::Empty)
1514 {
1515 mOut << " ";
1516 emitNameOf(decl);
1517 }
1518 }
1519
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1520 void GenMetalTraverser::emitVariableDeclaration(const VarDecl &decl,
1521 const EmitVariableDeclarationConfig &evdConfig)
1522 {
1523 const SymbolType symbolType = decl.symbolType();
1524 const TType &type = decl.type();
1525 const TBasicType basicType = type.getBasicType();
1526
1527 switch (basicType)
1528 {
1529 case TBasicType::EbtStruct:
1530 {
1531 if (type.isStructSpecifier() && !evdConfig.disableStructSpecifier)
1532 {
1533 // It's invalid to declare a struct inside a function argument. When emitting a
1534 // function parameter, the callsite should set evdConfig.disableStructSpecifier.
1535 ASSERT(!evdConfig.isParameter);
1536 emitStructDeclaration(type);
1537 if (symbolType != SymbolType::Empty)
1538 {
1539 mOut << " ";
1540 emitNameOf(decl);
1541 }
1542 }
1543 else
1544 {
1545 emitOrdinaryVariableDeclaration(decl, evdConfig);
1546 }
1547 }
1548 break;
1549
1550 default:
1551 {
1552 ASSERT(symbolType != SymbolType::Empty || evdConfig.isParameter);
1553 emitOrdinaryVariableDeclaration(decl, evdConfig);
1554 }
1555 }
1556
1557 if (evdConfig.emitPostQualifier)
1558 {
1559 emitPostQualifier(evdConfig, decl, type.getQualifier());
1560 }
1561 }
1562
visitSymbol(TIntermSymbol * symbolNode)1563 void GenMetalTraverser::visitSymbol(TIntermSymbol *symbolNode)
1564 {
1565 const TVariable &var = symbolNode->variable();
1566 const TType &type = var.getType();
1567 ASSERT(var.symbolType() != SymbolType::Empty);
1568
1569 if (type.getBasicType() == TBasicType::EbtVoid)
1570 {
1571 mOut << "/*";
1572 emitNameOf(var);
1573 mOut << "*/";
1574 }
1575 else
1576 {
1577 emitNameOf(var);
1578 }
1579 }
1580
emitSingleConstant(const TConstantUnion * const constUnion)1581 void GenMetalTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
1582 {
1583 switch (constUnion->getType())
1584 {
1585 case TBasicType::EbtBool:
1586 {
1587 mOut << (constUnion->getBConst() ? "true" : "false");
1588 }
1589 break;
1590
1591 case TBasicType::EbtFloat:
1592 {
1593 float value = constUnion->getFConst();
1594 if (std::isnan(value))
1595 {
1596 mOut << "NAN";
1597 }
1598 else if (std::isinf(value))
1599 {
1600 if (value < 0)
1601 {
1602 mOut << "-";
1603 }
1604 mOut << "INFINITY";
1605 }
1606 else
1607 {
1608 mOut << value << "f";
1609 }
1610 }
1611 break;
1612
1613 case TBasicType::EbtInt:
1614 {
1615 mOut << constUnion->getIConst();
1616 }
1617 break;
1618
1619 case TBasicType::EbtUInt:
1620 {
1621 mOut << constUnion->getUConst() << "u";
1622 }
1623 break;
1624
1625 default:
1626 {
1627 UNIMPLEMENTED();
1628 }
1629 }
1630 }
1631
// Emits `size` consecutive constants starting at `constUnion`, separated by ", ".
// Returns a pointer one past the last constant emitted.
const TConstantUnion *GenMetalTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    for (size_t position = 0; position < size; ++position)
    {
        if (position > 0)
        {
            mOut << ", ";
        }
        emitSingleConstant(constUnion + position);
    }
    return constUnion + size;
}
1648
// Emits the constant value of `type` whose components start at `constUnionBegin`.
//
// Structs are emitted as "<type>{field, field, ...}" with each field emitted recursively.
// Non-struct values with more than one component are emitted as "<type>(c0, c1, ...)";
// a single component is emitted bare.
//
// Returns a pointer one past the last component consumed.
const TConstantUnion *GenMetalTraverser::emitConstantUnion(const TType &type,
                                                           const TConstantUnion *constUnionBegin)
{
    const TStructure *const structure = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        if (componentCount <= 1)
        {
            return emitConstantUnionArray(constUnionBegin, componentCount);
        }

        emitType(type, EmitTypeConfig{nullptr});
        mOut << "(";
        const TConstantUnion *const next = emitConstantUnionArray(constUnionBegin, componentCount);
        mOut << ")";
        return next;
    }

    emitType(type, EmitTypeConfig{nullptr});
    mOut << "{";

    const TConstantUnion *cursor = constUnionBegin;
    const TFieldList &members    = structure->fields();
    const size_t memberCount     = members.size();
    for (size_t memberIndex = 0; memberIndex < memberCount; ++memberIndex)
    {
        if (memberIndex != 0)
        {
            mOut << ", ";
        }
        cursor = emitConstantUnion(*members[memberIndex]->type(), cursor);
    }

    mOut << "}";
    return cursor;
}
1689
visitConstantUnion(TIntermConstantUnion * constValueNode)1690 void GenMetalTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
1691 {
1692 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
1693 }
1694
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)1695 bool GenMetalTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
1696 {
1697 groupedTraverse(*swizzleNode->getOperand());
1698 mOut << ".";
1699
1700 {
1701 #if defined(ANGLE_ENABLE_ASSERTS)
1702 DebugSink::EscapedSink escapedOut(mOut.escape());
1703 TInfoSinkBase &out = escapedOut.get();
1704 #else
1705 TInfoSinkBase &out = mOut;
1706 #endif
1707 swizzleNode->writeOffsetsAsXYZW(&out);
1708 }
1709
1710 return false;
1711 }
1712
getDirectField(const TFieldListCollection & fieldListCollection,const TConstantUnion & index)1713 const TField &GenMetalTraverser::getDirectField(const TFieldListCollection &fieldListCollection,
1714 const TConstantUnion &index)
1715 {
1716 ASSERT(index.getType() == TBasicType::EbtInt);
1717
1718 const TFieldList &fieldList = fieldListCollection.fields();
1719 const int indexVal = index.getIConst();
1720 const TField &field = *fieldList[indexVal];
1721
1722 return field;
1723 }
1724
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)1725 const TField &GenMetalTraverser::getDirectField(const TIntermTyped &fieldsNode,
1726 TIntermTyped &indexNode)
1727 {
1728 const TType &fieldsType = fieldsNode.getType();
1729
1730 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
1731 if (fieldListCollection == nullptr)
1732 {
1733 fieldListCollection = fieldsType.getInterfaceBlock();
1734 }
1735 ASSERT(fieldListCollection);
1736
1737 const TIntermConstantUnion *indexNode_ = indexNode.getAsConstantUnion();
1738 ASSERT(indexNode_);
1739 const TConstantUnion &index = *indexNode_->getConstantValue();
1740
1741 return getDirectField(*fieldListCollection, index);
1742 }
1743
visitBinary(Visit,TIntermBinary * binaryNode)1744 bool GenMetalTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
1745 {
1746 const TOperator op = binaryNode->getOp();
1747 TIntermTyped &leftNode = *binaryNode->getLeft();
1748 TIntermTyped &rightNode = *binaryNode->getRight();
1749
1750 switch (op)
1751 {
1752 case TOperator::EOpIndexDirectStruct:
1753 case TOperator::EOpIndexDirectInterfaceBlock:
1754 {
1755 const TField &field = getDirectField(leftNode, rightNode);
1756 if (mSymbolEnv.isPointer(field) && mSymbolEnv.isUBO(field))
1757 {
1758 emitOpeningPointerParen();
1759 }
1760 groupedTraverse(leftNode);
1761 if (!mSymbolEnv.isPointer(field))
1762 {
1763 emitClosingPointerParen();
1764 }
1765 mOut << ".";
1766 emitNameOf(field);
1767 }
1768 break;
1769
1770 case TOperator::EOpIndexDirect:
1771 case TOperator::EOpIndexIndirect:
1772 {
1773 TType leftType = leftNode.getType();
1774 groupedTraverse(leftNode);
1775 mOut << "[";
1776 const TConstantUnion *constIndex = rightNode.getConstantValue();
1777 // TODO(anglebug.com/42266914): Convert type and bound checks to
1778 // assertions after AST validation is enabled for MSL translation.
1779 if (!leftType.isUnsizedArray() && constIndex != nullptr &&
1780 constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
1781 constIndex->getIConst() < static_cast<int>(leftType.isArray()
1782 ? leftType.getOutermostArraySize()
1783 : leftType.getNominalSize()))
1784 {
1785 emitSingleConstant(constIndex);
1786 }
1787 else
1788 {
1789 mOut << "ANGLE_int_clamp(";
1790 groupedTraverse(rightNode);
1791 mOut << ", 0, ";
1792 if (leftType.isUnsizedArray())
1793 {
1794 groupedTraverse(leftNode);
1795 mOut << ".size()";
1796 }
1797 else
1798 {
1799 uint32_t maxSize;
1800 if (leftType.isArray())
1801 {
1802 maxSize = leftType.getOutermostArraySize() - 1;
1803 }
1804 else
1805 {
1806 maxSize = leftType.getNominalSize() - 1;
1807 }
1808 mOut << maxSize;
1809 }
1810 mOut << ")";
1811 }
1812 mOut << "]";
1813 }
1814 break;
1815
1816 default:
1817 {
1818 const TType &resultType = binaryNode->getType();
1819 const TType &leftType = leftNode.getType();
1820 const TType &rightType = rightNode.getType();
1821
1822 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1823 {
1824 groupedTraverse(leftNode);
1825 if (op != TOperator::EOpComma)
1826 {
1827 mOut << " ";
1828 }
1829 else
1830 {
1831 emitClosingPointerParen();
1832 }
1833 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1834 groupedTraverse(rightNode);
1835 }
1836 else
1837 {
1838 emitClosingPointerParen();
1839 mOut << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1840 leftNode.traverse(this);
1841 mOut << ", ";
1842 rightNode.traverse(this);
1843 mOut << ")";
1844 }
1845 }
1846 }
1847
1848 return false;
1849 }
1850
IsPostfix(TOperator op)1851 static bool IsPostfix(TOperator op)
1852 {
1853 switch (op)
1854 {
1855 case TOperator::EOpPostIncrement:
1856 case TOperator::EOpPostDecrement:
1857 return true;
1858
1859 default:
1860 return false;
1861 }
1862 }
1863
visitUnary(Visit,TIntermUnary * unaryNode)1864 bool GenMetalTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1865 {
1866 const TOperator op = unaryNode->getOp();
1867 const TType &resultType = unaryNode->getType();
1868
1869 TIntermTyped &arg = *unaryNode->getOperand();
1870 const TType &argType = arg.getType();
1871
1872 if (op == TOperator::EOpIsnan || op == TOperator::EOpIsinf)
1873 {
1874 mtl::getTranslatorMetalReflection(&mCompiler)->hasIsnanOrIsinf = true;
1875 }
1876
1877 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1878
1879 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1880 {
1881 const bool postfix = IsPostfix(op);
1882 if (!postfix)
1883 {
1884 mOut << name;
1885 }
1886 groupedTraverse(arg);
1887 if (postfix)
1888 {
1889 mOut << name;
1890 }
1891 }
1892 else
1893 {
1894 mOut << name << "(";
1895 arg.traverse(this);
1896 mOut << ")";
1897 }
1898
1899 return false;
1900 }
1901
visitTernary(Visit,TIntermTernary * conditionalNode)1902 bool GenMetalTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1903 {
1904 groupedTraverse(*conditionalNode->getCondition());
1905 mOut << " ? ";
1906 groupedTraverse(*conditionalNode->getTrueExpression());
1907 mOut << " : ";
1908 groupedTraverse(*conditionalNode->getFalseExpression());
1909
1910 return false;
1911 }
1912
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1913 bool GenMetalTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1914 {
1915 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1916 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1917 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1918
1919 emitIndentation();
1920 mOut << "if (";
1921 condNode.traverse(this);
1922 mOut << ")";
1923
1924 if (thenNode)
1925 {
1926 mOut << "\n";
1927 thenNode->traverse(this);
1928 }
1929 else
1930 {
1931 mOut << " {}";
1932 }
1933
1934 if (elseNode)
1935 {
1936 mOut << "\n";
1937 emitIndentation();
1938 mOut << "else\n";
1939 elseNode->traverse(this);
1940 }
1941 else
1942 {
1943 // Always emit "else" even when empty block to avoid nested if-stmt issues.
1944 mOut << " else {}";
1945 }
1946
1947 return false;
1948 }
1949
visitSwitch(Visit,TIntermSwitch * switchNode)1950 bool GenMetalTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1951 {
1952 emitIndentation();
1953 mOut << "switch (";
1954 switchNode->getInit()->traverse(this);
1955 mOut << ")\n";
1956
1957 ASSERT(!mParentIsSwitch);
1958 mParentIsSwitch = true;
1959 switchNode->getStatementList()->traverse(this);
1960 mParentIsSwitch = false;
1961
1962 return false;
1963 }
1964
visitCase(Visit,TIntermCase * caseNode)1965 bool GenMetalTraverser::visitCase(Visit, TIntermCase *caseNode)
1966 {
1967 emitIndentation();
1968
1969 if (caseNode->hasCondition())
1970 {
1971 TIntermTyped *condExpr = caseNode->getCondition();
1972 mOut << "case ";
1973 condExpr->traverse(this);
1974 mOut << ":";
1975 }
1976 else
1977 {
1978 mOut << "default:\n";
1979 }
1980
1981 return false;
1982 }
1983
emitFunctionSignature(const TFunction & func)1984 void GenMetalTraverser::emitFunctionSignature(const TFunction &func)
1985 {
1986 const bool isMain = func.isMain();
1987
1988 emitFunctionReturn(func);
1989
1990 mOut << " ";
1991 emitNameOf(func);
1992 if (isMain)
1993 {
1994 mOut << "0";
1995 }
1996 mOut << "(";
1997
1998 bool emitComma = false;
1999 const size_t paramCount = func.getParamCount();
2000 for (size_t i = 0; i < paramCount; ++i)
2001 {
2002 if (emitComma)
2003 {
2004 mOut << ", ";
2005 }
2006 emitComma = true;
2007
2008 const TVariable ¶m = *func.getParam(i);
2009 emitFunctionParameter(func, param);
2010 }
2011
2012 if (isTraversingVertexMain)
2013 {
2014 mOut << " @@XFB-Bindings@@ ";
2015 }
2016
2017 mOut << ")";
2018 }
2019
emitFunctionReturn(const TFunction & func)2020 void GenMetalTraverser::emitFunctionReturn(const TFunction &func)
2021 {
2022 const bool isMain = func.isMain();
2023 bool isVertexMain = false;
2024 const TType &returnType = func.getReturnType();
2025 if (isMain)
2026 {
2027 const TStructure *structure = returnType.getStruct();
2028 if (mPipelineStructs.fragmentOut.matches(*structure))
2029 {
2030 if (mCompiler.isEarlyFragmentTestsSpecified())
2031 {
2032 mOut << "[[early_fragment_tests]]\n";
2033 }
2034 mOut << "fragment ";
2035 }
2036 else if (mPipelineStructs.vertexOut.matches(*structure))
2037 {
2038 ASSERT(structure != nullptr);
2039 mOut << "vertex __VERTEX_OUT(";
2040 isVertexMain = true;
2041 }
2042 else
2043 {
2044 UNIMPLEMENTED();
2045 }
2046 }
2047 emitType(returnType, EmitTypeConfig());
2048 if (isVertexMain)
2049 mOut << ") ";
2050 }
2051
emitFunctionParameter(const TFunction & func,const TVariable & param)2052 void GenMetalTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
2053 {
2054 const bool isMain = func.isMain();
2055
2056 const TType &type = param.getType();
2057 const TStructure *structure = type.getStruct();
2058
2059 EmitVariableDeclarationConfig evdConfig;
2060 evdConfig.isParameter = true;
2061 evdConfig.disableStructSpecifier = true; // It's invalid to declare a struct in a function arg.
2062 evdConfig.isMainParameter = isMain;
2063 evdConfig.emitPostQualifier = isMain;
2064 evdConfig.isUBO = mSymbolEnv.isUBO(param);
2065 evdConfig.isPointer = mSymbolEnv.isPointer(param);
2066 evdConfig.isReference = mSymbolEnv.isReference(param);
2067 emitVariableDeclaration(VarDecl(param), evdConfig);
2068
2069 if (isMain)
2070 {
2071 TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
2072 if (structure)
2073 {
2074 if (mPipelineStructs.fragmentIn.matches(*structure) ||
2075 mPipelineStructs.vertexIn.matches(*structure))
2076 {
2077 mOut << " [[stage_in]]";
2078 }
2079 else if (mPipelineStructs.angleUniforms.matches(*structure))
2080 {
2081 mOut << " [[buffer(" << mDriverUniformsBindingIndex << ")]]";
2082 }
2083 else if (mPipelineStructs.uniformBuffers.matches(*structure))
2084 {
2085 mOut << " [[buffer(" << mUBOArgumentBufferBindingIndex << ")]]";
2086 reflection->hasUBOs = true;
2087 }
2088 else if (mPipelineStructs.userUniforms.matches(*structure))
2089 {
2090 mOut << " [[buffer(" << mMainUniformBufferIndex << ")]]";
2091 reflection->addUserUniformBufferBinding(param.name().data(),
2092 mMainUniformBufferIndex);
2093 mMainUniformBufferIndex += type.getArraySizeProduct();
2094 }
2095 else if (structure->name() == "metal::sampler")
2096 {
2097 mOut << " [[sampler(" << (mMainSamplerIndex) << ")]]";
2098 const std::string originalName =
2099 reflection->getOriginalName(param.uniqueId().get());
2100 reflection->addSamplerBinding(originalName, mMainSamplerIndex);
2101 mMainSamplerIndex += type.getArraySizeProduct();
2102 }
2103 }
2104 else if (IsSampler(type.getBasicType()))
2105 {
2106 mOut << " [[texture(" << (mMainTextureIndex) << ")]]";
2107 const std::string originalName = reflection->getOriginalName(param.uniqueId().get());
2108 reflection->addTextureBinding(originalName, mMainTextureIndex);
2109 mMainTextureIndex += type.getArraySizeProduct();
2110 }
2111 else if (Name(param) == Pipeline{Pipeline::Type::InstanceId, nullptr}.getStructInstanceName(
2112 Pipeline::Variant::Modified))
2113 {
2114 mOut << " [[instance_id]]";
2115 }
2116 else if (Name(param) == kBaseInstanceName)
2117 {
2118 mOut << " [[base_instance]]";
2119 }
2120 }
2121 }
2122
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)2123 void GenMetalTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
2124 {
2125 const TFunction &func = *funcProtoNode->getFunction();
2126
2127 emitIndentation();
2128 emitFunctionSignature(func);
2129 }
2130
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)2131 bool GenMetalTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
2132 {
2133 const TFunction &func = *funcDefNode->getFunction();
2134 TIntermBlock &body = *funcDefNode->getBody();
2135 if (func.isMain())
2136 {
2137 const TType &returnType = func.getReturnType();
2138 const TStructure *structure = returnType.getStruct();
2139 isTraversingVertexMain = (mPipelineStructs.vertexOut.matches(*structure)) &&
2140 mCompiler.getShaderType() == GL_VERTEX_SHADER;
2141 }
2142 emitIndentation();
2143 emitFunctionSignature(func);
2144 mOut << "\n";
2145 body.traverse(this);
2146 if (isTraversingVertexMain)
2147 {
2148 isTraversingVertexMain = false;
2149 }
2150 return false;
2151 }
2152
BuildFuncToName()2153 GenMetalTraverser::FuncToName GenMetalTraverser::BuildFuncToName()
2154 {
2155 FuncToName map;
2156
2157 auto putAngle = [&](const char *nameStr) {
2158 const ImmutableString name(nameStr);
2159 ASSERT(map.find(name) == map.end());
2160 map[name] = Name(nameStr, SymbolType::AngleInternal);
2161 };
2162
2163 putAngle("texelFetch");
2164 putAngle("texelFetchOffset");
2165 putAngle("texture");
2166 putAngle("texture2D");
2167 putAngle("texture2DGradEXT");
2168 putAngle("texture2DLod");
2169 putAngle("texture2DLodEXT");
2170 putAngle("texture2DProj");
2171 putAngle("texture2DProjGradEXT");
2172 putAngle("texture2DProjLod");
2173 putAngle("texture2DProjLodEXT");
2174 putAngle("texture3D");
2175 putAngle("texture3DLod");
2176 putAngle("texture3DProj");
2177 putAngle("texture3DProjLod");
2178 putAngle("textureCube");
2179 putAngle("textureCubeGradEXT");
2180 putAngle("textureCubeLod");
2181 putAngle("textureCubeLodEXT");
2182 putAngle("textureGrad");
2183 putAngle("textureGradOffset");
2184 putAngle("textureLod");
2185 putAngle("textureLodOffset");
2186 putAngle("textureOffset");
2187 putAngle("textureProj");
2188 putAngle("textureProjGrad");
2189 putAngle("textureProjGradOffset");
2190 putAngle("textureProjLod");
2191 putAngle("textureProjLodOffset");
2192 putAngle("textureProjOffset");
2193 putAngle("textureSize");
2194 putAngle("imageLoad");
2195 putAngle("imageStore");
2196 putAngle("memoryBarrierImage");
2197
2198 return map;
2199 }
2200
visitAggregate(Visit,TIntermAggregate * aggregateNode)2201 bool GenMetalTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
2202 {
2203 const TIntermSequence &args = *aggregateNode->getSequence();
2204
2205 auto emitArgList = [&](const char *open, const char *close) {
2206 mOut << open;
2207
2208 bool emitComma = false;
2209 for (TIntermNode *arg : args)
2210 {
2211 if (emitComma)
2212 {
2213 emitClosingPointerParen();
2214 mOut << ", ";
2215 }
2216 emitComma = true;
2217 arg->traverse(this);
2218 }
2219
2220 mOut << close;
2221 };
2222
2223 const TType &retType = aggregateNode->getType();
2224
2225 if (aggregateNode->isConstructor())
2226 {
2227 const bool isStandalone = getParentNode()->getAsBlock();
2228 if (isStandalone)
2229 {
2230 // Prevent constructor from being interpreted as a declaration by wrapping in parens.
2231 // This can happen if given something like:
2232 // int(symbol); // <- This will be treated like `int symbol;`... don't want that.
2233 // So instead emit:
2234 // (int(symbol));
2235 mOut << "(";
2236 }
2237
2238 const EmitTypeConfig etConfig;
2239
2240 if (retType.isArray())
2241 {
2242 emitType(retType, etConfig);
2243 emitArgList("{", "}");
2244 }
2245 else if (retType.getStruct())
2246 {
2247 emitType(retType, etConfig);
2248 emitArgList("{", "}");
2249 }
2250 else
2251 {
2252 emitType(retType, etConfig);
2253 emitArgList("(", ")");
2254 }
2255
2256 if (isStandalone)
2257 {
2258 mOut << ")";
2259 }
2260
2261 return false;
2262 }
2263 else
2264 {
2265 const TOperator op = aggregateNode->getOp();
2266 switch (op)
2267 {
2268 case TOperator::EOpCallFunctionInAST:
2269 case TOperator::EOpCallInternalRawFunction:
2270 {
2271 const TFunction &func = *aggregateNode->getFunction();
2272 emitNameOf(func);
2273 //'@' symbol in name specifices a macro substitution marker.
2274 if (!func.name().contains("@"))
2275 {
2276 emitArgList("(", ")");
2277 }
2278 else
2279 {
2280 mTemporarilyDisableSemicolon =
2281 true; // Disable semicolon for macro substitution.
2282 }
2283 return false;
2284 }
2285
2286 default:
2287 {
2288 auto getArgType = [&](size_t index) -> const TType * {
2289 if (index < args.size())
2290 {
2291 TIntermTyped *arg = args[index]->getAsTyped();
2292 ASSERT(arg);
2293 return &arg->getType();
2294 }
2295 return nullptr;
2296 };
2297
2298 const TType *argType0 = getArgType(0);
2299 const TType *argType1 = getArgType(1);
2300 const TType *argType2 = getArgType(2);
2301
2302 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
2303
2304 if (IsSymbolicOperator(op, retType, argType0, argType1))
2305 {
2306 switch (args.size())
2307 {
2308 case 1:
2309 {
2310 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
2311 if (IsPostfix(op))
2312 {
2313 mOut << opName;
2314 groupedTraverse(operandNode);
2315 }
2316 else
2317 {
2318 groupedTraverse(operandNode);
2319 mOut << opName;
2320 }
2321 return false;
2322 }
2323
2324 case 2:
2325 {
2326 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
2327 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
2328 groupedTraverse(leftNode);
2329 mOut << " " << opName << " ";
2330 groupedTraverse(rightNode);
2331 return false;
2332 }
2333
2334 default:
2335 UNREACHABLE();
2336 return false;
2337 }
2338 }
2339 else if (opName == nullptr)
2340 {
2341 const TFunction &func = *aggregateNode->getFunction();
2342 auto it = mFuncToName.find(func.name());
2343 ASSERT(it != mFuncToName.end());
2344 EmitName(mOut, it->second);
2345 emitArgList("(", ")");
2346 return false;
2347 }
2348 else
2349 {
2350 mOut << opName;
2351 emitArgList("(", ")");
2352 return false;
2353 }
2354 }
2355 }
2356 }
2357 }
2358
emitOpenBrace()2359 void GenMetalTraverser::emitOpenBrace()
2360 {
2361 ASSERT(mIndentLevel >= 0);
2362
2363 emitIndentation();
2364 mOut << "{\n";
2365 ++mIndentLevel;
2366 }
2367
emitCloseBrace()2368 void GenMetalTraverser::emitCloseBrace()
2369 {
2370 ASSERT(mIndentLevel >= 1);
2371
2372 --mIndentLevel;
2373 emitIndentation();
2374 mOut << "}";
2375 }
2376
RequiresSemicolonTerminator(TIntermNode & node)2377 static bool RequiresSemicolonTerminator(TIntermNode &node)
2378 {
2379 if (node.getAsBlock())
2380 {
2381 return false;
2382 }
2383 if (node.getAsLoopNode())
2384 {
2385 return false;
2386 }
2387 if (node.getAsSwitchNode())
2388 {
2389 return false;
2390 }
2391 if (node.getAsIfElseNode())
2392 {
2393 return false;
2394 }
2395 if (node.getAsFunctionDefinition())
2396 {
2397 return false;
2398 }
2399 if (node.getAsCaseNode())
2400 {
2401 return false;
2402 }
2403
2404 return true;
2405 }
2406
NewlinePad(TIntermNode & node)2407 static bool NewlinePad(TIntermNode &node)
2408 {
2409 if (node.getAsFunctionDefinition())
2410 {
2411 return true;
2412 }
2413 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
2414 {
2415 ASSERT(declNode->getChildCount() == 1);
2416 TIntermNode &childNode = *declNode->getChildNode(0);
2417 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
2418 {
2419 const TVariable &var = symbolNode->variable();
2420 return var.getType().isStructSpecifier();
2421 }
2422 return false;
2423 }
2424 return false;
2425 }
2426
visitBlock(Visit,TIntermBlock * blockNode)2427 bool GenMetalTraverser::visitBlock(Visit, TIntermBlock *blockNode)
2428 {
2429 ASSERT(mIndentLevel >= -1);
2430 const bool isGlobalScope = mIndentLevel == -1;
2431 const bool parentIsSwitch = mParentIsSwitch;
2432 mParentIsSwitch = false;
2433
2434 if (isGlobalScope)
2435 {
2436 ++mIndentLevel;
2437 }
2438 else
2439 {
2440 emitOpenBrace();
2441 if (parentIsSwitch)
2442 {
2443 ++mIndentLevel;
2444 }
2445 }
2446
2447 TIntermNode *prevStmtNode = nullptr;
2448
2449 const size_t stmtCount = blockNode->getChildCount();
2450 for (size_t i = 0; i < stmtCount; ++i)
2451 {
2452 TIntermNode &stmtNode = *blockNode->getChildNode(i);
2453
2454 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
2455 {
2456 mOut << "\n";
2457 }
2458 const bool isCase = stmtNode.getAsCaseNode();
2459 mIndentLevel -= isCase;
2460 emitIndentation();
2461 mIndentLevel += isCase;
2462 stmtNode.traverse(this);
2463 if (RequiresSemicolonTerminator(stmtNode) && !mTemporarilyDisableSemicolon)
2464 {
2465 mOut << ";";
2466 }
2467 mTemporarilyDisableSemicolon = false;
2468 mOut << "\n";
2469
2470 prevStmtNode = &stmtNode;
2471 }
2472
2473 if (isGlobalScope)
2474 {
2475 ASSERT(mIndentLevel == 0);
2476 --mIndentLevel;
2477 }
2478 else
2479 {
2480 if (parentIsSwitch)
2481 {
2482 ASSERT(mIndentLevel >= 1);
2483 --mIndentLevel;
2484 }
2485 emitCloseBrace();
2486 mParentIsSwitch = parentIsSwitch;
2487 }
2488
2489 return false;
2490 }
2491
// Global qualifier declarations are ignored by this emitter: nothing is written to the output.
// NOTE(review): returning false presumably tells the traverser not to descend into the node's
// children -- confirm against the TIntermTraverser contract.
bool GenMetalTraverser::visitGlobalQualifierDeclaration(Visit, TIntermGlobalQualifierDeclaration *)
{
    return false;
}
2496
visitDeclaration(Visit,TIntermDeclaration * declNode)2497 bool GenMetalTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
2498 {
2499 ASSERT(declNode->getChildCount() == 1);
2500 TIntermNode &node = *declNode->getChildNode(0);
2501
2502 EmitVariableDeclarationConfig evdConfig;
2503
2504 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
2505 {
2506 const TVariable &var = symbolNode->variable();
2507 emitVariableDeclaration(VarDecl(var), evdConfig);
2508 if (var.getType().isArray() && var.getType().getQualifier() == EvqTemporary)
2509 {
2510 // The translator frontend injects a loop-based init for user arrays when the source
2511 // shader is using ESSL 1.00. Some Metal drivers may fail to access elements of such
2512 // arrays at runtime depending on the array size. An empty literal initializer added
2513 // to the generated MSL bypasses the issue. The frontend may be further optimized to
2514 // skip the loop-based init when targeting MSL.
2515 mOut << "{}";
2516 }
2517 }
2518 else if (TIntermBinary *initNode = node.getAsBinaryNode())
2519 {
2520 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
2521 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
2522 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
2523 ASSERT(leftSymbolNode && valueNode);
2524
2525 if (getRootNode() == getParentBlock())
2526 {
2527 // DeferGlobalInitializers should have turned non-const global initializers into
2528 // deferred initializers. Note that variables marked as EvqGlobal can be treated as
2529 // EvqConst in some ANGLE code but not actually have their qualifier actually changed to
2530 // EvqConst. Thus just assume all EvqGlobal are actually EvqConst for all code run after
2531 // DeferGlobalInitializers.
2532 mOut << "constant ";
2533 }
2534
2535 const TVariable &var = leftSymbolNode->variable();
2536 const Name varName(var);
2537
2538 if (ExpressionContainsName(varName, *valueNode))
2539 {
2540 mRenamedSymbols[&var] = mIdGen.createNewName(varName);
2541 }
2542
2543 emitVariableDeclaration(VarDecl(var), evdConfig);
2544 mOut << " = ";
2545 groupedTraverse(*valueNode);
2546 }
2547 else
2548 {
2549 UNREACHABLE();
2550 }
2551
2552 return false;
2553 }
2554
visitLoop(Visit,TIntermLoop * loopNode)2555 bool GenMetalTraverser::visitLoop(Visit, TIntermLoop *loopNode)
2556 {
2557 const TLoopType loopType = loopNode->getType();
2558
2559 switch (loopType)
2560 {
2561 case TLoopType::ELoopFor:
2562 return visitForLoop(loopNode);
2563 case TLoopType::ELoopWhile:
2564 return visitWhileLoop(loopNode);
2565 case TLoopType::ELoopDoWhile:
2566 return visitDoWhileLoop(loopNode);
2567 }
2568 }
2569
visitForLoop(TIntermLoop * loopNode)2570 bool GenMetalTraverser::visitForLoop(TIntermLoop *loopNode)
2571 {
2572 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
2573
2574 TIntermNode *initNode = loopNode->getInit();
2575 TIntermTyped *condNode = loopNode->getCondition();
2576 TIntermTyped *exprNode = loopNode->getExpression();
2577
2578 mOut << "for (";
2579
2580 if (initNode)
2581 {
2582 initNode->traverse(this);
2583 }
2584 else
2585 {
2586 mOut << " ";
2587 }
2588
2589 mOut << "; ";
2590
2591 if (condNode)
2592 {
2593 condNode->traverse(this);
2594 }
2595
2596 mOut << "; ";
2597
2598 if (exprNode)
2599 {
2600 exprNode->traverse(this);
2601 }
2602
2603 mOut << ")\n";
2604
2605 emitLoopBody(loopNode->getBody());
2606
2607 return false;
2608 }
2609
visitWhileLoop(TIntermLoop * loopNode)2610 bool GenMetalTraverser::visitWhileLoop(TIntermLoop *loopNode)
2611 {
2612 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
2613
2614 TIntermNode *initNode = loopNode->getInit();
2615 TIntermTyped *condNode = loopNode->getCondition();
2616 TIntermTyped *exprNode = loopNode->getExpression();
2617 ASSERT(condNode);
2618 ASSERT(!initNode && !exprNode);
2619
2620 emitIndentation();
2621 mOut << "while (";
2622 condNode->traverse(this);
2623 mOut << ")\n";
2624 emitLoopBody(loopNode->getBody());
2625
2626 return false;
2627 }
2628
visitDoWhileLoop(TIntermLoop * loopNode)2629 bool GenMetalTraverser::visitDoWhileLoop(TIntermLoop *loopNode)
2630 {
2631 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
2632
2633 TIntermNode *initNode = loopNode->getInit();
2634 TIntermTyped *condNode = loopNode->getCondition();
2635 TIntermTyped *exprNode = loopNode->getExpression();
2636 ASSERT(condNode);
2637 ASSERT(!initNode && !exprNode);
2638
2639 emitIndentation();
2640 mOut << "do\n";
2641 emitLoopBody(loopNode->getBody());
2642 mOut << "\n";
2643 emitIndentation();
2644 mOut << "while (";
2645 condNode->traverse(this);
2646 mOut << ");";
2647
2648 return false;
2649 }
2650
visitBranch(Visit,TIntermBranch * branchNode)2651 bool GenMetalTraverser::visitBranch(Visit, TIntermBranch *branchNode)
2652 {
2653 const TOperator flowOp = branchNode->getFlowOp();
2654 TIntermTyped *exprNode = branchNode->getExpression();
2655
2656 emitIndentation();
2657
2658 switch (flowOp)
2659 {
2660 case TOperator::EOpKill:
2661 {
2662 ASSERT(exprNode == nullptr);
2663 mOut << "metal::discard_fragment()";
2664 }
2665 break;
2666
2667 case TOperator::EOpReturn:
2668 {
2669 if (isTraversingVertexMain)
2670 {
2671 mOut << "#if TRANSFORM_FEEDBACK_ENABLED\n";
2672 emitIndentation();
2673 mOut << "return;\n";
2674 emitIndentation();
2675 mOut << "#else\n";
2676 emitIndentation();
2677 }
2678 mOut << "return";
2679 if (exprNode)
2680 {
2681 mOut << " ";
2682 exprNode->traverse(this);
2683 mOut << ";";
2684 }
2685 if (isTraversingVertexMain)
2686 {
2687 mOut << "\n";
2688 emitIndentation();
2689 mOut << "#endif\n";
2690 mTemporarilyDisableSemicolon = true;
2691 }
2692 }
2693 break;
2694
2695 case TOperator::EOpBreak:
2696 {
2697 ASSERT(exprNode == nullptr);
2698 mOut << "break";
2699 }
2700 break;
2701
2702 case TOperator::EOpContinue:
2703 {
2704 ASSERT(exprNode == nullptr);
2705 mOut << "continue";
2706 }
2707 break;
2708
2709 default:
2710 {
2711 UNREACHABLE();
2712 }
2713 }
2714
2715 return false;
2716 }
2717
// Counts calls to EmitMetal. The current count is appended to the GMD_FIXED_EMIT path prefix to
// select which replacement file is loaded for a given call.
static size_t emitMetalCallCount = 0;
2719
EmitMetal(TCompiler & compiler,TIntermBlock & root,IdGen & idGen,const PipelineStructs & pipelineStructs,SymbolEnv & symbolEnv,const ProgramPreludeConfig & ppc,const ShCompileOptions & compileOptions)2720 bool sh::EmitMetal(TCompiler &compiler,
2721 TIntermBlock &root,
2722 IdGen &idGen,
2723 const PipelineStructs &pipelineStructs,
2724 SymbolEnv &symbolEnv,
2725 const ProgramPreludeConfig &ppc,
2726 const ShCompileOptions &compileOptions)
2727 {
2728 TInfoSinkBase &out = compiler.getInfoSink().obj;
2729
2730 {
2731 ++emitMetalCallCount;
2732 std::string filenameProto = angle::GetEnvironmentVar("GMD_FIXED_EMIT");
2733 if (!filenameProto.empty())
2734 {
2735 if (filenameProto != "/dev/null")
2736 {
2737 auto tryOpen = [&](char const *ext) {
2738 auto filename = filenameProto;
2739 filename += std::to_string(emitMetalCallCount);
2740 filename += ".";
2741 filename += ext;
2742 return fopen(filename.c_str(), "rb");
2743 };
2744 FILE *file = tryOpen("metal");
2745 if (!file)
2746 {
2747 file = tryOpen("cpp");
2748 }
2749 ASSERT(file);
2750
2751 fseek(file, 0, SEEK_END);
2752 size_t fileSize = ftell(file);
2753 fseek(file, 0, SEEK_SET);
2754
2755 std::vector<char> buff;
2756 buff.resize(fileSize + 1);
2757 fread(buff.data(), fileSize, 1, file);
2758 buff.back() = '\0';
2759
2760 fclose(file);
2761
2762 out << buff.data();
2763 }
2764
2765 return true;
2766 }
2767 }
2768
2769 out << "\n\n";
2770
2771 if (!EmitProgramPrelude(root, out, ppc))
2772 {
2773 return false;
2774 }
2775
2776 {
2777 #if defined(ANGLE_ENABLE_ASSERTS)
2778 DebugSink outWrapper(out, angle::GetBoolEnvironmentVar("GMD_STDOUT"));
2779 outWrapper.watch(angle::GetEnvironmentVar("GMD_WATCH_STRING"));
2780 #else
2781 TInfoSinkBase &outWrapper = out;
2782 #endif
2783 GenMetalTraverser gen(compiler, outWrapper, idGen, pipelineStructs, symbolEnv,
2784 compileOptions);
2785 root.traverse(&gen);
2786 }
2787
2788 out << "\n";
2789
2790 return true;
2791 }
2792