xref: /aosp_15_r20/external/angle/src/compiler/translator/wgsl/TranslatorWGSL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/wgsl/TranslatorWGSL.h"
8 
9 #include <iostream>
10 #include <variant>
11 
12 #include "GLSLANG/ShaderLang.h"
13 #include "common/log_utils.h"
14 #include "compiler/translator/BaseTypes.h"
15 #include "compiler/translator/Common.h"
16 #include "compiler/translator/Diagnostics.h"
17 #include "compiler/translator/ImmutableString.h"
18 #include "compiler/translator/ImmutableStringBuilder.h"
19 #include "compiler/translator/InfoSink.h"
20 #include "compiler/translator/IntermNode.h"
21 #include "compiler/translator/OutputTree.h"
22 #include "compiler/translator/StaticType.h"
23 #include "compiler/translator/SymbolUniqueId.h"
24 #include "compiler/translator/Types.h"
25 #include "compiler/translator/tree_util/BuiltIn_autogen.h"
26 #include "compiler/translator/tree_util/FindMain.h"
27 #include "compiler/translator/tree_util/IntermNode_util.h"
28 #include "compiler/translator/tree_util/IntermTraverse.h"
29 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
30 #include "compiler/translator/wgsl/OutputUniformBlocks.h"
31 #include "compiler/translator/wgsl/RewritePipelineVariables.h"
32 #include "compiler/translator/wgsl/Utils.h"
33 
34 namespace sh
35 {
36 namespace
37 {
38 
// Compile-time debugging switches, both off. Their consumers live outside this part of the file;
// NOTE(review): judging by the names they dump the AST before translation and the generated WGSL
// respectively - confirm at the use sites.
constexpr bool kOutputTreeBeforeTranslation{false};
constexpr bool kOutputTranslatedShader{false};
41 
42 struct VarDecl
43 {
44     const SymbolType symbolType = SymbolType::Empty;
45     const ImmutableString &symbolName;
46     const TType &type;
47 };
48 
IsDefaultUniform(const TType & type)49 bool IsDefaultUniform(const TType &type)
50 {
51     return type.getQualifier() == EvqUniform && type.getInterfaceBlock() == nullptr &&
52            !IsOpaqueType(type.getBasicType());
53 }
54 
55 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)56 bool RequiresSemicolonTerminator(TIntermNode &node)
57 {
58     if (node.getAsBlock())
59     {
60         return false;
61     }
62     if (node.getAsLoopNode())
63     {
64         return false;
65     }
66     if (node.getAsSwitchNode())
67     {
68         return false;
69     }
70     if (node.getAsIfElseNode())
71     {
72         return false;
73     }
74     if (node.getAsFunctionDefinition())
75     {
76         return false;
77     }
78     if (node.getAsCaseNode())
79     {
80         return false;
81     }
82 
83     return true;
84 }
85 
86 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)87 bool NewlinePad(TIntermNode &node)
88 {
89     if (node.getAsFunctionDefinition())
90     {
91         return true;
92     }
93     if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
94     {
95         ASSERT(declNode->getChildCount() == 1);
96         TIntermNode &childNode = *declNode->getChildNode(0);
97         if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
98         {
99             const TVariable &var = symbolNode->variable();
100             return var.getType().isStructSpecifier();
101         }
102         return false;
103     }
104     return false;
105 }
106 
107 // A traverser that generates WGSL as it walks the AST.
108 class OutputWGSLTraverser : public TIntermTraverser
109 {
110   public:
111     OutputWGSLTraverser(TCompiler *compiler,
112                         RewritePipelineVarOutput *rewritePipelineVarOutput,
113                         UniformBlockMetadata *uniformBlockMetadata);
114     ~OutputWGSLTraverser() override;
115 
116   protected:
117     void visitSymbol(TIntermSymbol *node) override;
118     void visitConstantUnion(TIntermConstantUnion *node) override;
119     bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
120     bool visitBinary(Visit visit, TIntermBinary *node) override;
121     bool visitUnary(Visit visit, TIntermUnary *node) override;
122     bool visitTernary(Visit visit, TIntermTernary *node) override;
123     bool visitIfElse(Visit visit, TIntermIfElse *node) override;
124     bool visitSwitch(Visit visit, TIntermSwitch *node) override;
125     bool visitCase(Visit visit, TIntermCase *node) override;
126     void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
127     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
128     bool visitAggregate(Visit visit, TIntermAggregate *node) override;
129     bool visitBlock(Visit visit, TIntermBlock *node) override;
130     bool visitGlobalQualifierDeclaration(Visit visit,
131                                          TIntermGlobalQualifierDeclaration *node) override;
132     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
133     bool visitLoop(Visit visit, TIntermLoop *node) override;
134     bool visitBranch(Visit visit, TIntermBranch *node) override;
135     void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
136 
137   private:
138     struct EmitVariableDeclarationConfig
139     {
140         bool isParameter            = false;
141         bool disableStructSpecifier = false;
142         bool needsVar               = false;
143         bool isGlobalScope          = false;
144     };
145 
146     void groupedTraverse(TIntermNode &node);
147     void emitNameOf(const VarDecl &decl);
148     void emitBareTypeName(const TType &type);
149     void emitType(const TType &type);
150     void emitSingleConstant(const TConstantUnion *const constUnion);
151     const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
152                                                  const size_t size);
153     const TConstantUnion *emitConstantUnion(const TType &type,
154                                             const TConstantUnion *constUnionBegin);
155     const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
156     void emitIndentation();
157     void emitOpenBrace();
158     void emitCloseBrace();
159     bool emitBlock(TSpan<TIntermNode *> nodes);
160     void emitFunctionSignature(const TFunction &func);
161     void emitFunctionReturn(const TFunction &func);
162     void emitFunctionParameter(const TFunction &func, const TVariable &param);
163     void emitStructDeclaration(const TType &type);
164     void emitVariableDeclaration(const VarDecl &decl,
165                                  const EmitVariableDeclarationConfig &evdConfig);
166     void emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode);
167 
168     bool emitForLoop(TIntermLoop *);
169     bool emitWhileLoop(TIntermLoop *);
170     bool emulateDoWhileLoop(TIntermLoop *);
171 
172     TInfoSinkBase &mSink;
173     RewritePipelineVarOutput *mRewritePipelineVarOutput;
174     UniformBlockMetadata *mUniformBlockMetadata;
175 
176     int mIndentLevel        = -1;
177     int mLastIndentationPos = -1;
178 };
179 
OutputWGSLTraverser(TCompiler * compiler,RewritePipelineVarOutput * rewritePipelineVarOutput,UniformBlockMetadata * uniformBlockMetadata)180 OutputWGSLTraverser::OutputWGSLTraverser(TCompiler *compiler,
181                                          RewritePipelineVarOutput *rewritePipelineVarOutput,
182                                          UniformBlockMetadata *uniformBlockMetadata)
183     : TIntermTraverser(true, false, false),
184       mSink(compiler->getInfoSink().obj),
185       mRewritePipelineVarOutput(rewritePipelineVarOutput),
186       mUniformBlockMetadata(uniformBlockMetadata)
187 {}
188 
189 OutputWGSLTraverser::~OutputWGSLTraverser() = default;
190 
groupedTraverse(TIntermNode & node)191 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
192 {
193     // TODO(anglebug.com/42267100): to make generated code more readable, do not always
194     // emit parentheses like WGSL is some Lisp dialect.
195     const bool emitParens = true;
196 
197     if (emitParens)
198     {
199         mSink << "(";
200     }
201 
202     node.traverse(this);
203 
204     if (emitParens)
205     {
206         mSink << ")";
207     }
208 }
209 
emitNameOf(const VarDecl & decl)210 void OutputWGSLTraverser::emitNameOf(const VarDecl &decl)
211 {
212     WriteNameOf(mSink, decl.symbolType, decl.symbolName);
213 }
214 
emitIndentation()215 void OutputWGSLTraverser::emitIndentation()
216 {
217     ASSERT(mIndentLevel >= 0);
218 
219     if (mLastIndentationPos == mSink.size())
220     {
221         return;  // Line is already indented.
222     }
223 
224     for (int i = 0; i < mIndentLevel; ++i)
225     {
226         mSink << "  ";
227     }
228 
229     mLastIndentationPos = mSink.size();
230 }
231 
emitOpenBrace()232 void OutputWGSLTraverser::emitOpenBrace()
233 {
234     ASSERT(mIndentLevel >= 0);
235 
236     emitIndentation();
237     mSink << "{\n";
238     ++mIndentLevel;
239 }
240 
emitCloseBrace()241 void OutputWGSLTraverser::emitCloseBrace()
242 {
243     ASSERT(mIndentLevel >= 1);
244 
245     --mIndentLevel;
246     emitIndentation();
247     mSink << "}";
248 }
249 
visitSymbol(TIntermSymbol * symbolNode)250 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
251 {
252 
253     const TVariable &var = symbolNode->variable();
254     const TType &type    = var.getType();
255     ASSERT(var.symbolType() != SymbolType::Empty);
256 
257     if (type.getBasicType() == TBasicType::EbtVoid)
258     {
259         UNREACHABLE();
260     }
261     else
262     {
263         // Accesses of pipeline variables should be rewritten as struct accesses.
264         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()))
265         {
266             mSink << kBuiltinInputStructName << "." << var.name();
267         }
268         else if (mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
269         {
270             mSink << kBuiltinOutputStructName << "." << var.name();
271         }
272         // Accesses of basic uniforms need to be converted to struct accesses.
273         else if (IsDefaultUniform(type))
274         {
275             mSink << kDefaultUniformBlockVarName << "." << var.name();
276         }
277         else
278         {
279             WriteNameOf(mSink, var);
280         }
281 
282         if (var.symbolType() == SymbolType::BuiltIn)
283         {
284             ASSERT(mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
285                    mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()) ||
286                    var.uniqueId() == BuiltInId::gl_DepthRange);
287             // TODO(anglebug.com/42267100): support gl_DepthRange.
288             // Match the name of the struct field in `mRewritePipelineVarOutput`.
289             mSink << "_";
290         }
291     }
292 }
293 
emitSingleConstant(const TConstantUnion * const constUnion)294 void OutputWGSLTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
295 {
296     switch (constUnion->getType())
297     {
298         case TBasicType::EbtBool:
299         {
300             mSink << (constUnion->getBConst() ? "true" : "false");
301         }
302         break;
303 
304         case TBasicType::EbtFloat:
305         {
306             float value = constUnion->getFConst();
307             if (std::isnan(value))
308             {
309                 UNIMPLEMENTED();
310                 // TODO(anglebug.com/42267100): this is not a valid constant in WGPU.
311                 // You can't even do something like bitcast<f32>(0xffffffffu).
312                 // The WGSL compiler still complains. I think this is because
313                 // WGSL supports implementations compiling with -ffastmath and
314                 // therefore nans and infinities are assumed to not exist.
315                 // See also https://github.com/gpuweb/gpuweb/issues/3749.
316                 mSink << "NAN_INVALID";
317             }
318             else if (std::isinf(value))
319             {
320                 UNIMPLEMENTED();
321                 // see above.
322                 mSink << "INFINITY_INVALID";
323             }
324             else
325             {
326                 mSink << value << "f";
327             }
328         }
329         break;
330 
331         case TBasicType::EbtInt:
332         {
333             mSink << constUnion->getIConst() << "i";
334         }
335         break;
336 
337         case TBasicType::EbtUInt:
338         {
339             mSink << constUnion->getUConst() << "u";
340         }
341         break;
342 
343         default:
344         {
345             UNIMPLEMENTED();
346         }
347     }
348 }
349 
// Writes |size| consecutive scalar constants starting at |constUnion|, separated by ", ".
// Returns a pointer to the element just past the last one written (|constUnion| itself when
// |size| is zero).
const TConstantUnion *OutputWGSLTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const pastLast = constUnion + size;
    for (const TConstantUnion *current = constUnion; current != pastLast; ++current)
    {
        // A separator goes in front of every element except the first.
        if (current != constUnion)
        {
            mSink << ", ";
        }
        emitSingleConstant(current);
    }
    return pastLast;
}
366 
// Writes the constant of type |type| whose scalar components start at |constUnionBegin|, and
// returns a pointer just past the last component consumed.
//
// Non-struct types with a single component are written as a bare literal. Larger non-struct types
// are written as `Type(c0, c1, ...)`. Struct types are written as `Type(field0, field1, ...)`,
// with each field emitted recursively.
// NOTE(review): a type with a struct is always given exactly one set of fields here; confirm that
// constant arrays of structs cannot reach this function.
const TConstantUnion *OutputWGSLTraverser::emitConstantUnion(const TType &type,
                                                             const TConstantUnion *constUnionBegin)
{
    const TStructure *structure = type.getStruct();

    if (structure == nullptr)
    {
        const size_t componentCount = type.getObjectSize();
        if (componentCount <= 1)
        {
            return emitConstantUnionArray(constUnionBegin, componentCount);
        }

        // More than one component (vectors, matrices, arrays): the value has to be spelled as a
        // constructor call, i.e. the type followed by the parenthesized components.
        emitType(type);
        mSink << "(";
        const TConstantUnion *afterComponents =
            emitConstantUnionArray(constUnionBegin, componentCount);
        mSink << ")";
        return afterComponents;
    }

    // WGSL constructs structs with parentheses.
    emitType(type);
    mSink << "(";

    // One constructor argument per struct field; GLSL and WGSL both require the counts to match.
    const TConstantUnion *cursor = constUnionBegin;
    bool isFirstField            = true;
    for (auto *field : structure->fields())
    {
        if (!isFirstField)
        {
            mSink << ", ";
        }
        isFirstField = false;

        const TType *fieldType = field->type();
        cursor                 = emitConstantUnion(*fieldType, cursor);
    }

    mSink << ")";
    return cursor;
}
410 
visitConstantUnion(TIntermConstantUnion * constValueNode)411 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
412 {
413     emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
414 }
415 
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)416 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
417 {
418     groupedTraverse(*swizzleNode->getOperand());
419     mSink << ".";
420     swizzleNode->writeOffsetsAsXYZW(&mSink);
421 
422     return false;
423 }
424 
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)425 const char *GetOperatorString(TOperator op,
426                               const TType &resultType,
427                               const TType *argType0,
428                               const TType *argType1,
429                               const TType *argType2)
430 {
431     switch (op)
432     {
433         case TOperator::EOpComma:
434             // WGSL does not have a comma operator or any other way to implement "statement list as
435             // an expression", so nested expressions will have to be pulled out into statements.
436             UNIMPLEMENTED();
437             return "TODO_operator";
438         case TOperator::EOpAssign:
439             return "=";
440         case TOperator::EOpInitialize:
441             return "=";
442         // Compound assignments now exist: https://www.w3.org/TR/WGSL/#compound-assignment-sec
443         case TOperator::EOpAddAssign:
444             return "+=";
445         case TOperator::EOpSubAssign:
446             return "-=";
447         case TOperator::EOpMulAssign:
448             return "*=";
449         case TOperator::EOpDivAssign:
450             return "/=";
451         case TOperator::EOpIModAssign:
452             return "%=";
453         case TOperator::EOpBitShiftLeftAssign:
454             return "<<=";
455         case TOperator::EOpBitShiftRightAssign:
456             return ">>=";
457         case TOperator::EOpBitwiseAndAssign:
458             return "&=";
459         case TOperator::EOpBitwiseXorAssign:
460             return "^=";
461         case TOperator::EOpBitwiseOrAssign:
462             return "|=";
463         case TOperator::EOpAdd:
464             return "+";
465         case TOperator::EOpSub:
466             return "-";
467         case TOperator::EOpMul:
468             return "*";
469         case TOperator::EOpDiv:
470             return "/";
471         // TODO(anglebug.com/42267100): Works different from GLSL for negative numbers.
472         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=not%20WGSL%3B%20etc.-,Inconsistent%20mod/%25%20operator,-At%20first%20glance
473         // GLSL does `x - y * floor(x/y)`, WGSL does x - y * trunc(x/y).
474         case TOperator::EOpIMod:
475         case TOperator::EOpMod:
476             return "%";
477         case TOperator::EOpBitShiftLeft:
478             return "<<";
479         case TOperator::EOpBitShiftRight:
480             return ">>";
481         case TOperator::EOpBitwiseAnd:
482             return "&";
483         case TOperator::EOpBitwiseXor:
484             return "^";
485         case TOperator::EOpBitwiseOr:
486             return "|";
487         case TOperator::EOpLessThan:
488             return "<";
489         case TOperator::EOpGreaterThan:
490             return ">";
491         case TOperator::EOpLessThanEqual:
492             return "<=";
493         case TOperator::EOpGreaterThanEqual:
494             return ">=";
495         // Component-wise comparisons are done with regular infix operators in WGSL:
496         // https://www.w3.org/TR/WGSL/#comparison-expr
497         case TOperator::EOpLessThanComponentWise:
498             return "<";
499         case TOperator::EOpLessThanEqualComponentWise:
500             return "<=";
501         case TOperator::EOpGreaterThanEqualComponentWise:
502             return ">=";
503         case TOperator::EOpGreaterThanComponentWise:
504             return ">";
505         case TOperator::EOpLogicalOr:
506             return "||";
507         // Logical XOR is only applied to boolean expressions so it's the same as "not equals".
508         // Neither short-circuits.
509         case TOperator::EOpLogicalXor:
510             return "!=";
511         case TOperator::EOpLogicalAnd:
512             return "&&";
513         case TOperator::EOpNegative:
514             return "-";
515         case TOperator::EOpPositive:
516             if (argType0->isMatrix())
517             {
518                 return "";
519             }
520             return "+";
521         case TOperator::EOpLogicalNot:
522             return "!";
523         // Component-wise not done with normal prefix unary operator in WGSL:
524         // https://www.w3.org/TR/WGSL/#logical-expr
525         case TOperator::EOpNotComponentWise:
526             return "!";
527         case TOperator::EOpBitwiseNot:
528             return "~";
529         // TODO(anglebug.com/42267100): increment operations cannot be used as expressions in WGSL.
530         case TOperator::EOpPostIncrement:
531             return "++";
532         case TOperator::EOpPostDecrement:
533             return "--";
534         case TOperator::EOpPreIncrement:
535         case TOperator::EOpPreDecrement:
536             // TODO(anglebug.com/42267100): pre increments and decrements do not exist in WGSL.
537             UNIMPLEMENTED();
538             return "TODO_operator";
539         case TOperator::EOpVectorTimesScalarAssign:
540             return "*=";
541         case TOperator::EOpVectorTimesMatrixAssign:
542             return "*=";
543         case TOperator::EOpMatrixTimesScalarAssign:
544             return "*=";
545         case TOperator::EOpMatrixTimesMatrixAssign:
546             return "*=";
547         case TOperator::EOpVectorTimesScalar:
548             return "*";
549         case TOperator::EOpVectorTimesMatrix:
550             return "*";
551         case TOperator::EOpMatrixTimesVector:
552             return "*";
553         case TOperator::EOpMatrixTimesScalar:
554             return "*";
555         case TOperator::EOpMatrixTimesMatrix:
556             return "*";
557         case TOperator::EOpEqualComponentWise:
558             return "==";
559         case TOperator::EOpNotEqualComponentWise:
560             return "!=";
561 
562         // TODO(anglebug.com/42267100): structs, matrices, and arrays are not comparable with WGSL's
563         // == or !=. Comparing vectors results in a component-wise comparison returning a boolean
564         // vector, which is different from GLSL (which use equal(vec, vec) for component-wise
565         // comparison)
566         case TOperator::EOpEqual:
567             if ((argType0->isVector() && argType1->isVector()) ||
568                 (argType0->getStruct() && argType1->getStruct()) ||
569                 (argType0->isArray() && argType1->isArray()) ||
570                 (argType0->isMatrix() && argType1->isMatrix()))
571 
572             {
573                 UNIMPLEMENTED();
574                 return "TODO_operator";
575             }
576 
577             return "==";
578 
579         case TOperator::EOpNotEqual:
580             if ((argType0->isVector() && argType1->isVector()) ||
581                 (argType0->getStruct() && argType1->getStruct()) ||
582                 (argType0->isArray() && argType1->isArray()) ||
583                 (argType0->isMatrix() && argType1->isMatrix()))
584             {
585                 UNIMPLEMENTED();
586                 return "TODO_operator";
587             }
588             return "!=";
589 
590         case TOperator::EOpKill:
591         case TOperator::EOpReturn:
592         case TOperator::EOpBreak:
593         case TOperator::EOpContinue:
594             // These should all be emitted in visitBranch().
595             UNREACHABLE();
596             return "UNREACHABLE_operator";
597         case TOperator::EOpRadians:
598             return "radians";
599         case TOperator::EOpDegrees:
600             return "degrees";
601         case TOperator::EOpAtan:
602             return argType1 == nullptr ? "atan" : "atan2";
603         case TOperator::EOpRefract:
604             return argType0->isVector() ? "refract" : "TODO_operator";
605         case TOperator::EOpDistance:
606             return "distance";
607         case TOperator::EOpLength:
608             return "length";
609         case TOperator::EOpDot:
610             return argType0->isVector() ? "dot" : "*";
611         case TOperator::EOpNormalize:
612             return argType0->isVector() ? "normalize" : "sign";
613         case TOperator::EOpFaceforward:
614             return argType0->isVector() ? "faceForward" : "TODO_Operator";
615         case TOperator::EOpReflect:
616             return argType0->isVector() ? "reflect" : "TODO_Operator";
617         case TOperator::EOpMatrixCompMult:
618             return "TODO_Operator";
619         case TOperator::EOpOuterProduct:
620             return "TODO_Operator";
621         case TOperator::EOpSign:
622             return "sign";
623 
624         case TOperator::EOpAbs:
625             return "abs";
626         case TOperator::EOpAll:
627             return "all";
628         case TOperator::EOpAny:
629             return "any";
630         case TOperator::EOpSin:
631             return "sin";
632         case TOperator::EOpCos:
633             return "cos";
634         case TOperator::EOpTan:
635             return "tan";
636         case TOperator::EOpAsin:
637             return "asin";
638         case TOperator::EOpAcos:
639             return "acos";
640         case TOperator::EOpSinh:
641             return "sinh";
642         case TOperator::EOpCosh:
643             return "cosh";
644         case TOperator::EOpTanh:
645             return "tanh";
646         case TOperator::EOpAsinh:
647             return "asinh";
648         case TOperator::EOpAcosh:
649             return "acosh";
650         case TOperator::EOpAtanh:
651             return "atanh";
652         case TOperator::EOpFma:
653             return "fma";
654         // TODO(anglebug.com/42267100): Won't accept pow(vec<f32>, f32).
655         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=Similarly%20pow(vec3%3Cf32%3E%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL
656         case TOperator::EOpPow:
657             return "pow";  // GLSL's pow excludes negative x
658         case TOperator::EOpExp:
659             return "exp";
660         case TOperator::EOpExp2:
661             return "exp2";
662         case TOperator::EOpLog:
663             return "log";
664         case TOperator::EOpLog2:
665             return "log2";
666         case TOperator::EOpSqrt:
667             return "sqrt";
668         case TOperator::EOpFloor:
669             return "floor";
670         case TOperator::EOpTrunc:
671             return "trunc";
672         case TOperator::EOpCeil:
673             return "ceil";
674         case TOperator::EOpFract:
675             return "fract";
676         case TOperator::EOpMin:
677             return "min";
678         case TOperator::EOpMax:
679             return "max";
680         case TOperator::EOpRound:
681             return "round";  // TODO(anglebug.com/42267100): this is wrong and must round away from
682                              // zero if there is a tie. This always rounds to the even number.
683         case TOperator::EOpRoundEven:
684             return "round";
685         // TODO(anglebug.com/42267100):
686         // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=clamp(vec2%3Cf32%3E%2C%20f32%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL%3B%20etc.
687         // Need to expand clamp(vec<f32>, low : f32, high : f32) ->
688         // clamp(vec<f32>, vec<f32>(low), vec<f32>(high))
689         case TOperator::EOpClamp:
690             return "clamp";
691         case TOperator::EOpSaturate:
692             return "saturate";
693         case TOperator::EOpMix:
694             if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
695             {
696                 return "TODO_Operator";
697             }
698             return "mix";
699         case TOperator::EOpStep:
700             return "step";
701         case TOperator::EOpSmoothstep:
702             return "smoothstep";
703         case TOperator::EOpModf:
704             UNIMPLEMENTED();  // TODO(anglebug.com/42267100): in WGSL this returns a struct, GLSL it
705                               // uses a return value and an outparam
706             return "modf";
707         case TOperator::EOpIsnan:
708         case TOperator::EOpIsinf:
709             UNIMPLEMENTED();  // TODO(anglebug.com/42267100): WGSL does not allow NaNs or infinity.
710                               // What to do about shaders that require this?
711             // Implementations are allowed to assume overflow, infinities, and NaNs are not present
712             // at runtime, however. https://www.w3.org/TR/WGSL/#floating-point-evaluation
713             return "TODO_Operator";
714         case TOperator::EOpLdexp:
715             // TODO(anglebug.com/42267100): won't accept first arg vector, second arg scalar
716             return "ldexp";
717         case TOperator::EOpFrexp:
718             return "frexp";  // TODO(anglebug.com/42267100): returns a struct
719         case TOperator::EOpInversesqrt:
720             return "inverseSqrt";
721         case TOperator::EOpCross:
722             return "cross";
723             // TODO(anglebug.com/42267100): are these the same? dpdxCoarse() vs dpdxFine()?
724         case TOperator::EOpDFdx:
725             return "dpdx";
726         case TOperator::EOpDFdy:
727             return "dpdy";
728         case TOperator::EOpFwidth:
729             return "fwidth";
730         case TOperator::EOpTranspose:
731             return "transpose";
732         case TOperator::EOpDeterminant:
733             return "determinant";
734 
735         case TOperator::EOpInverse:
736             return "TODO_Operator";  // No builtin invert().
737                                      // https://github.com/gpuweb/gpuweb/issues/4115
738 
739         // TODO(anglebug.com/42267100): these interpolateAt*() are not builtin
740         case TOperator::EOpInterpolateAtCentroid:
741             return "TODO_Operator";
742         case TOperator::EOpInterpolateAtSample:
743             return "TODO_Operator";
744         case TOperator::EOpInterpolateAtOffset:
745             return "TODO_Operator";
746         case TOperator::EOpInterpolateAtCenter:
747             return "TODO_Operator";
748 
749         case TOperator::EOpFloatBitsToInt:
750         case TOperator::EOpFloatBitsToUint:
751         case TOperator::EOpIntBitsToFloat:
752         case TOperator::EOpUintBitsToFloat:
753         {
754 #define BITCAST_SCALAR()                   \
755     do                                     \
756         switch (resultType.getBasicType()) \
757         {                                  \
758             case TBasicType::EbtInt:       \
759                 return "bitcast<i32>";     \
760             case TBasicType::EbtUInt:      \
761                 return "bitcast<u32>";     \
762             case TBasicType::EbtFloat:     \
763                 return "bitcast<f32>";     \
764             default:                       \
765                 UNIMPLEMENTED();           \
766                 return "TOperator_TODO";   \
767         }                                  \
768     while (false)
769 
770 #define BITCAST_VECTOR(vecSize)                        \
771     do                                                 \
772         switch (resultType.getBasicType())             \
773         {                                              \
774             case TBasicType::EbtInt:                   \
775                 return "bitcast<vec" vecSize "<i32>>"; \
776             case TBasicType::EbtUInt:                  \
777                 return "bitcast<vec" vecSize "<u32>>"; \
778             case TBasicType::EbtFloat:                 \
779                 return "bitcast<vec" vecSize "<f32>>"; \
780             default:                                   \
781                 UNIMPLEMENTED();                       \
782                 return "TOperator_TODO";               \
783         }                                              \
784     while (false)
785 
786             if (resultType.isScalar())
787             {
788                 BITCAST_SCALAR();
789             }
790             else if (resultType.isVector())
791             {
792                 switch (resultType.getNominalSize())
793                 {
794                     case 2:
795                         BITCAST_VECTOR("2");
796                     case 3:
797                         BITCAST_VECTOR("3");
798                     case 4:
799                         BITCAST_VECTOR("4");
800                     default:
801                         UNREACHABLE();
802                         return nullptr;
803                 }
804             }
805             else
806             {
807                 UNIMPLEMENTED();
808                 return "TOperator_TODO";
809             }
810 
811 #undef BITCAST_SCALAR
812 #undef BITCAST_VECTOR
813         }
814 
815         case TOperator::EOpPackUnorm2x16:
816             return "pack2x16unorm";
817         case TOperator::EOpPackSnorm2x16:
818             return "pack2x16snorm";
819 
820         case TOperator::EOpPackUnorm4x8:
821             return "pack4x8unorm";
822         case TOperator::EOpPackSnorm4x8:
823             return "pack4x8snorm";
824 
825         case TOperator::EOpUnpackUnorm2x16:
826             return "unpack2x16unorm";
827         case TOperator::EOpUnpackSnorm2x16:
828             return "unpack2x16snorm";
829 
830         case TOperator::EOpUnpackUnorm4x8:
831             return "unpack4x8unorm";
832         case TOperator::EOpUnpackSnorm4x8:
833             return "unpack4x8snorm";
834 
835         case TOperator::EOpPackHalf2x16:
836             return "pack2x16float";
837         case TOperator::EOpUnpackHalf2x16:
838             return "unpack2x16float";
839 
840         case TOperator::EOpBarrier:
841             UNREACHABLE();
842             return "TOperator_TODO";
843         case TOperator::EOpMemoryBarrier:
844             // TODO(anglebug.com/42267100): does this exist in WGPU? Device-scoped memory barrier?
845             // Maybe storageBarrier()?
846             UNREACHABLE();
847             return "TOperator_TODO";
848         case TOperator::EOpGroupMemoryBarrier:
849             return "workgroupBarrier";
850         case TOperator::EOpMemoryBarrierAtomicCounter:
851         case TOperator::EOpMemoryBarrierBuffer:
852         case TOperator::EOpMemoryBarrierShared:
853             UNREACHABLE();
854             return "TOperator_TODO";
855         case TOperator::EOpAtomicAdd:
856             return "atomicAdd";
857         case TOperator::EOpAtomicMin:
858             return "atomicMin";
859         case TOperator::EOpAtomicMax:
860             return "atomicMax";
861         case TOperator::EOpAtomicAnd:
862             return "atomicAnd";
863         case TOperator::EOpAtomicOr:
864             return "atomicOr";
865         case TOperator::EOpAtomicXor:
866             return "atomicXor";
867         case TOperator::EOpAtomicExchange:
868             return "atomicExchange";
869         case TOperator::EOpAtomicCompSwap:
870             return "atomicCompareExchangeWeak";  // TODO(anglebug.com/42267100): returns a struct.
871         case TOperator::EOpBitfieldExtract:
872         case TOperator::EOpBitfieldInsert:
873         case TOperator::EOpBitfieldReverse:
874         case TOperator::EOpBitCount:
875         case TOperator::EOpFindLSB:
876         case TOperator::EOpFindMSB:
877         case TOperator::EOpUaddCarry:
878         case TOperator::EOpUsubBorrow:
879         case TOperator::EOpUmulExtended:
880         case TOperator::EOpImulExtended:
881         case TOperator::EOpEmitVertex:
882         case TOperator::EOpEndPrimitive:
883         case TOperator::EOpArrayLength:
884             UNIMPLEMENTED();
885             return "TOperator_TODO";
886 
887         case TOperator::EOpNull:
888         case TOperator::EOpConstruct:
889         case TOperator::EOpCallFunctionInAST:
890         case TOperator::EOpCallInternalRawFunction:
891         case TOperator::EOpIndexDirect:
892         case TOperator::EOpIndexIndirect:
893         case TOperator::EOpIndexDirectStruct:
894         case TOperator::EOpIndexDirectInterfaceBlock:
895             UNREACHABLE();
896             return nullptr;
897         default:
898             // Any other built-in function.
899             return nullptr;
900     }
901 }
902 
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)903 bool IsSymbolicOperator(TOperator op,
904                         const TType &resultType,
905                         const TType *argType0,
906                         const TType *argType1)
907 {
908     const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
909     if (operatorString == nullptr)
910     {
911         return false;
912     }
913     return !std::isalnum(operatorString[0]);
914 }
915 
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)916 const TField &OutputWGSLTraverser::getDirectField(const TIntermTyped &fieldsNode,
917                                                   TIntermTyped &indexNode)
918 {
919     const TType &fieldsType = fieldsNode.getType();
920 
921     const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
922     if (fieldListCollection == nullptr)
923     {
924         fieldListCollection = fieldsType.getInterfaceBlock();
925     }
926     ASSERT(fieldListCollection);
927 
928     const TIntermConstantUnion *indexNodeAsConstantUnion = indexNode.getAsConstantUnion();
929     ASSERT(indexNodeAsConstantUnion);
930     const TConstantUnion &index = *indexNodeAsConstantUnion->getConstantValue();
931 
932     ASSERT(index.getType() == TBasicType::EbtInt);
933 
934     const TFieldList &fieldList = fieldListCollection->fields();
935     const int indexVal          = index.getIConst();
936     const TField &field         = *fieldList[indexVal];
937 
938     return field;
939 }
940 
emitArrayIndex(TIntermTyped & leftNode,TIntermTyped & rightNode)941 void OutputWGSLTraverser::emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode)
942 {
943 
944     {
945         TType leftType = leftNode.getType();
946         groupedTraverse(leftNode);
947         mSink << "[";
948         const TConstantUnion *constIndex = rightNode.getConstantValue();
949         // If the array index is a constant that we can statically verify is within array
950         // bounds, just emit that constant.
951         if (!leftType.isUnsizedArray() && constIndex != nullptr &&
952             constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
953             constIndex->getIConst() < static_cast<int>(leftType.isArray()
954                                                            ? leftType.getOutermostArraySize()
955                                                            : leftType.getNominalSize()))
956         {
957             emitSingleConstant(constIndex);
958         }
959         else
960         {
961             // If the array index is not a constant within the bounds of the array, clamp the
962             // index.
963             mSink << "clamp(";
964             groupedTraverse(rightNode);
965             mSink << ", 0, ";
966             // Now find the array size and clamp it.
967             if (leftType.isUnsizedArray())
968             {
969                 // TODO(anglebug.com/42267100): This is a bug to traverse the `leftNode` a
970                 // second time if `leftNode` has side effects (and could also have performance
971                 // implications). This should be stored in a temporary variable. This might also
972                 // be a bug in the MSL shader compiler.
973                 mSink << "arrayLength(&";
974                 groupedTraverse(leftNode);
975                 mSink << ")";
976             }
977             else
978             {
979                 uint32_t maxSize;
980                 if (leftType.isArray())
981                 {
982                     maxSize = leftType.getOutermostArraySize() - 1;
983                 }
984                 else
985                 {
986                     maxSize = leftType.getNominalSize() - 1;
987                 }
988                 mSink << maxSize;
989             }
990             // End the clamp() function.
991             mSink << ")";
992         }
993         // End the array index operation.
994         mSink << "]";
995     }
996 }
997 
visitBinary(Visit,TIntermBinary * binaryNode)998 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
999 {
1000     const TOperator op      = binaryNode->getOp();
1001     TIntermTyped &leftNode  = *binaryNode->getLeft();
1002     TIntermTyped &rightNode = *binaryNode->getRight();
1003 
1004     switch (op)
1005     {
1006         case TOperator::EOpIndexDirectStruct:
1007         case TOperator::EOpIndexDirectInterfaceBlock:
1008             groupedTraverse(leftNode);
1009             mSink << ".";
1010             WriteNameOf(mSink, getDirectField(leftNode, rightNode));
1011             break;
1012 
1013         case TOperator::EOpIndexDirect:
1014         case TOperator::EOpIndexIndirect:
1015             emitArrayIndex(leftNode, rightNode);
1016             break;
1017 
1018         default:
1019         {
1020             const TType &resultType = binaryNode->getType();
1021             const TType &leftType   = leftNode.getType();
1022             const TType &rightType  = rightNode.getType();
1023 
1024             // x * y, x ^ y, etc.
1025             if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1026             {
1027                 groupedTraverse(leftNode);
1028                 if (op != TOperator::EOpComma)
1029                 {
1030                     mSink << " ";
1031                 }
1032                 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1033                 groupedTraverse(rightNode);
1034             }
1035             // E.g. builtin function calls
1036             else
1037             {
1038                 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1039                 leftNode.traverse(this);
1040                 mSink << ", ";
1041                 rightNode.traverse(this);
1042                 mSink << ")";
1043             }
1044         }
1045     }
1046 
1047     return false;
1048 }
1049 
IsPostfix(TOperator op)1050 bool IsPostfix(TOperator op)
1051 {
1052     switch (op)
1053     {
1054         case TOperator::EOpPostIncrement:
1055         case TOperator::EOpPostDecrement:
1056             return true;
1057 
1058         default:
1059             return false;
1060     }
1061 }
1062 
visitUnary(Visit,TIntermUnary * unaryNode)1063 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1064 {
1065     const TOperator op      = unaryNode->getOp();
1066     const TType &resultType = unaryNode->getType();
1067 
1068     TIntermTyped &arg    = *unaryNode->getOperand();
1069     const TType &argType = arg.getType();
1070 
1071     const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1072 
1073     // Examples: -x, ~x, ~x
1074     if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1075     {
1076         const bool postfix = IsPostfix(op);
1077         if (!postfix)
1078         {
1079             mSink << name;
1080         }
1081         groupedTraverse(arg);
1082         if (postfix)
1083         {
1084             mSink << name;
1085         }
1086     }
1087     else
1088     {
1089         mSink << name << "(";
1090         arg.traverse(this);
1091         mSink << ")";
1092     }
1093 
1094     return false;
1095 }
1096 
visitTernary(Visit,TIntermTernary * conditionalNode)1097 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1098 {
1099     // WGSL does not have a ternary. https://github.com/gpuweb/gpuweb/issues/3747
1100     // The select() builtin is not short circuiting. Maybe we can get if () {} else {} as an
1101     // expression, which would also solve the comma operator problem.
1102     // TODO(anglebug.com/42267100): as mentioned above this is not correct if the operands have side
1103     // effects. Even if they don't have side effects it could have performance implications.
1104     mSink << "select(";
1105     groupedTraverse(*conditionalNode->getTrueExpression());
1106     mSink << ", ";
1107     groupedTraverse(*conditionalNode->getFalseExpression());
1108     mSink << ", ";
1109     groupedTraverse(*conditionalNode->getCondition());
1110     mSink << ")";
1111 
1112     return false;
1113 }
1114 
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1115 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1116 {
1117     TIntermTyped &condNode = *ifThenElseNode->getCondition();
1118     TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1119     TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1120 
1121     mSink << "if (";
1122     condNode.traverse(this);
1123     mSink << ")";
1124 
1125     if (thenNode)
1126     {
1127         mSink << "\n";
1128         thenNode->traverse(this);
1129     }
1130     else
1131     {
1132         mSink << " {}";
1133     }
1134 
1135     if (elseNode)
1136     {
1137         mSink << "\n";
1138         emitIndentation();
1139         mSink << "else\n";
1140         elseNode->traverse(this);
1141     }
1142 
1143     return false;
1144 }
1145 
visitSwitch(Visit,TIntermSwitch * switchNode)1146 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1147 {
1148     TIntermBlock &stmtList = *switchNode->getStatementList();
1149 
1150     emitIndentation();
1151     mSink << "switch ";
1152     switchNode->getInit()->traverse(this);
1153     mSink << "\n";
1154 
1155     emitOpenBrace();
1156 
1157     // TODO(anglebug.com/42267100): Case statements that fall through need to combined into a single
1158     // case statement with multiple labels.
1159 
1160     const size_t stmtCount = stmtList.getChildCount();
1161     bool inCaseList        = false;
1162     size_t currStmt        = 0;
1163     while (currStmt < stmtCount)
1164     {
1165         TIntermNode &stmtNode = *stmtList.getChildNode(currStmt);
1166         TIntermCase *caseNode = stmtNode.getAsCaseNode();
1167         if (caseNode)
1168         {
1169             if (inCaseList)
1170             {
1171                 mSink << ", ";
1172             }
1173             else
1174             {
1175                 emitIndentation();
1176                 mSink << "case ";
1177                 inCaseList = true;
1178             }
1179             caseNode->traverse(this);
1180 
1181             // Process the next statement.
1182             currStmt++;
1183         }
1184         else
1185         {
1186             // The current statement is not a case statement, end the current case list and emit all
1187             // the code until the next case statement. WGSL requires braces around the case
1188             // statement's code.
1189             ASSERT(inCaseList);
1190             inCaseList = false;
1191             mSink << ":\n";
1192 
1193             // Count the statements until the next case (or the end of the switch) and emit them as
1194             // a block. This assumes that the current statement list will never fallthrough to the
1195             // next case statement.
1196             size_t nextCaseStmt = currStmt + 1;
1197             for (;
1198                  nextCaseStmt < stmtCount && !stmtList.getChildNode(nextCaseStmt)->getAsCaseNode();
1199                  nextCaseStmt++)
1200             {
1201             }
1202             TSpan<TIntermNode *> stmtListView(&stmtList.getSequence()->at(currStmt),
1203                                               nextCaseStmt - currStmt);
1204             emitBlock(stmtListView);
1205             mSink << "\n";
1206 
1207             // Skip to the next case statement.
1208             currStmt = nextCaseStmt;
1209         }
1210     }
1211 
1212     emitCloseBrace();
1213 
1214     return false;
1215 }
1216 
visitCase(Visit,TIntermCase * caseNode)1217 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
1218 {
1219     // "case" will have been emitted in the visitSwitch() override.
1220 
1221     if (caseNode->hasCondition())
1222     {
1223         TIntermTyped *condExpr = caseNode->getCondition();
1224         condExpr->traverse(this);
1225     }
1226     else
1227     {
1228         mSink << "default";
1229     }
1230 
1231     return false;
1232 }
1233 
emitFunctionReturn(const TFunction & func)1234 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
1235 {
1236     const TType &returnType = func.getReturnType();
1237     if (returnType.getBasicType() == EbtVoid)
1238     {
1239         return;
1240     }
1241     mSink << " -> ";
1242     emitType(returnType);
1243 }
1244 
1245 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
1246 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
1247 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
1248 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)1249 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
1250 {
1251     mSink << "fn ";
1252 
1253     WriteNameOf(mSink, func);
1254     mSink << "(";
1255 
1256     bool emitComma          = false;
1257     const size_t paramCount = func.getParamCount();
1258     for (size_t i = 0; i < paramCount; ++i)
1259     {
1260         if (emitComma)
1261         {
1262             mSink << ", ";
1263         }
1264         emitComma = true;
1265 
1266         const TVariable &param = *func.getParam(i);
1267         emitFunctionParameter(func, param);
1268     }
1269 
1270     mSink << ")";
1271 
1272     emitFunctionReturn(func);
1273 }
1274 
emitFunctionParameter(const TFunction & func,const TVariable & param)1275 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable &param)
1276 {
1277     // TODO(anglebug.com/42267100): function parameters are immutable and will need to be renamed if
1278     // they are mutated.
1279     EmitVariableDeclarationConfig evdConfig;
1280     evdConfig.isParameter = true;
1281     emitVariableDeclaration({param.symbolType(), param.name(), param.getType()}, evdConfig);
1282 }
1283 
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1284 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1285 {
1286     const TFunction &func = *funcProtoNode->getFunction();
1287 
1288     emitIndentation();
1289     // TODO(anglebug.com/42267100): output correct signature for main() if main() is declared as a
1290     // function prototype, or perhaps just emit nothing.
1291     emitFunctionSignature(func);
1292 }
1293 
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1294 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1295 {
1296     const TFunction &func = *funcDefNode->getFunction();
1297     TIntermBlock &body    = *funcDefNode->getBody();
1298 
1299     emitIndentation();
1300     emitFunctionSignature(func);
1301     mSink << "\n";
1302     body.traverse(this);
1303 
1304     return false;
1305 }
1306 
visitAggregate(Visit,TIntermAggregate * aggregateNode)1307 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1308 {
1309     const TIntermSequence &args = *aggregateNode->getSequence();
1310 
1311     auto emitArgList = [&]() {
1312         mSink << "(";
1313 
1314         bool emitComma = false;
1315         for (TIntermNode *arg : args)
1316         {
1317             if (emitComma)
1318             {
1319                 mSink << ", ";
1320             }
1321             emitComma = true;
1322             arg->traverse(this);
1323         }
1324 
1325         mSink << ")";
1326     };
1327 
1328     const TType &retType = aggregateNode->getType();
1329 
1330     if (aggregateNode->isConstructor())
1331     {
1332 
1333         emitType(retType);
1334         emitArgList();
1335 
1336         return false;
1337     }
1338     else
1339     {
1340         const TOperator op = aggregateNode->getOp();
1341         switch (op)
1342         {
1343             case TOperator::EOpCallFunctionInAST:
1344                 WriteNameOf(mSink, *aggregateNode->getFunction());
1345                 emitArgList();
1346                 return false;
1347 
1348             default:
1349                 // Do not allow raw function calls, i.e. calls to functions
1350                 // not present in the AST.
1351                 ASSERT(op != TOperator::EOpCallInternalRawFunction);
1352                 auto getArgType = [&](size_t index) -> const TType * {
1353                     if (index < args.size())
1354                     {
1355                         TIntermTyped *arg = args[index]->getAsTyped();
1356                         ASSERT(arg);
1357                         return &arg->getType();
1358                     }
1359                     return nullptr;
1360                 };
1361 
1362                 const TType *argType0 = getArgType(0);
1363                 const TType *argType1 = getArgType(1);
1364                 const TType *argType2 = getArgType(2);
1365 
1366                 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
1367 
1368                 if (IsSymbolicOperator(op, retType, argType0, argType1))
1369                 {
1370                     switch (args.size())
1371                     {
1372                         case 1:
1373                         {
1374                             TIntermNode &operandNode = *aggregateNode->getChildNode(0);
1375                             if (IsPostfix(op))
1376                             {
1377                                 mSink << opName;
1378                                 groupedTraverse(operandNode);
1379                             }
1380                             else
1381                             {
1382                                 groupedTraverse(operandNode);
1383                                 mSink << opName;
1384                             }
1385                             return false;
1386                         }
1387 
1388                         case 2:
1389                         {
1390                             // symbolic operators with 2 args are emitted with infix notation.
1391                             TIntermNode &leftNode  = *aggregateNode->getChildNode(0);
1392                             TIntermNode &rightNode = *aggregateNode->getChildNode(1);
1393                             groupedTraverse(leftNode);
1394                             mSink << " " << opName << " ";
1395                             groupedTraverse(rightNode);
1396                             return false;
1397                         }
1398 
1399                         default:
1400                             UNREACHABLE();
1401                             return false;
1402                     }
1403                 }
1404                 else
1405                 {
1406                     if (opName == nullptr)
1407                     {
1408                         // TODO(anglebug.com/42267100): opName should not be allowed to be nullptr
1409                         // here, but for now not all builtins are mapped to a string.
1410                         opName = "TODO_Operator";
1411                     }
1412                     // If the operator is not symbolic then it is a builtin that uses function call
1413                     // syntax: builtin(arg1, arg2, ..);
1414                     mSink << opName;
1415                     emitArgList();
1416                     return false;
1417                 }
1418         }
1419     }
1420 }
1421 
emitBlock(TSpan<TIntermNode * > nodes)1422 bool OutputWGSLTraverser::emitBlock(TSpan<TIntermNode *> nodes)
1423 {
1424     ASSERT(mIndentLevel >= -1);
1425     const bool isGlobalScope = mIndentLevel == -1;
1426 
1427     if (isGlobalScope)
1428     {
1429         ++mIndentLevel;
1430     }
1431     else
1432     {
1433         emitOpenBrace();
1434     }
1435 
1436     TIntermNode *prevStmtNode = nullptr;
1437 
1438     const size_t stmtCount = nodes.size();
1439     for (size_t i = 0; i < stmtCount; ++i)
1440     {
1441         TIntermNode &stmtNode = *nodes[i];
1442 
1443         if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
1444         {
1445             mSink << "\n";
1446         }
1447         const bool isCase = stmtNode.getAsCaseNode();
1448         mIndentLevel -= isCase;
1449         emitIndentation();
1450         mIndentLevel += isCase;
1451         stmtNode.traverse(this);
1452         if (RequiresSemicolonTerminator(stmtNode))
1453         {
1454             mSink << ";";
1455         }
1456         mSink << "\n";
1457 
1458         prevStmtNode = &stmtNode;
1459     }
1460 
1461     if (isGlobalScope)
1462     {
1463         ASSERT(mIndentLevel == 0);
1464         --mIndentLevel;
1465     }
1466     else
1467     {
1468         emitCloseBrace();
1469     }
1470 
1471     return false;
1472 }
1473 
visitBlock(Visit,TIntermBlock * blockNode)1474 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
1475 {
1476     return emitBlock(TSpan(blockNode->getSequence()->data(), blockNode->getSequence()->size()));
1477 }
1478 
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)1479 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
1480                                                           TIntermGlobalQualifierDeclaration *)
1481 {
1482     return false;
1483 }
1484 
emitStructDeclaration(const TType & type)1485 void OutputWGSLTraverser::emitStructDeclaration(const TType &type)
1486 {
1487     ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1488     ASSERT(type.isStructSpecifier());
1489 
1490     mSink << "struct ";
1491     emitBareTypeName(type);
1492 
1493     mSink << "\n";
1494     emitOpenBrace();
1495 
1496     const TStructure &structure = *type.getStruct();
1497     bool isInUniformAddressSpace =
1498         mUniformBlockMetadata->structsInUniformAddressSpace.count(structure.uniqueId().get()) != 0;
1499 
1500     bool alignTo16InUniformAddressSpace = true;
1501     for (const TField *field : structure.fields())
1502     {
1503         emitIndentation();
1504         // If this struct is used in the uniform address space, it must obey the uniform address
1505         // space's layout constaints (https://www.w3.org/TR/WGSL/#address-space-layout-constraints).
1506         // WGSL's address space layout constraints nearly match std140, and the places they don't
1507         // are handled elsewhere.
1508         if (isInUniformAddressSpace)
1509         {
1510             // Here, the field must be aligned to 16 if:
1511             // 1. The field is a struct or array
1512             // 2. The previous field is a struct
1513             // 3. The field is the first in the struct (for convenience).
1514             if (field->type()->getStruct() || field->type()->isArray())
1515             {
1516                 alignTo16InUniformAddressSpace = true;
1517             }
1518             if (alignTo16InUniformAddressSpace)
1519             {
1520                 mSink << "@align(16) ";
1521             }
1522 
1523             // If this field is a struct, the next member should be aligned to 16.
1524             alignTo16InUniformAddressSpace = field->type()->getStruct();
1525         }
1526 
1527         // TODO(anglebug.com/42267100): emit qualifiers.
1528         EmitVariableDeclarationConfig evdConfig;
1529         evdConfig.disableStructSpecifier = true;
1530         emitVariableDeclaration({field->symbolType(), field->name(), *field->type()}, evdConfig);
1531         mSink << ",\n";
1532     }
1533 
1534     emitCloseBrace();
1535 }
1536 
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1537 void OutputWGSLTraverser::emitVariableDeclaration(const VarDecl &decl,
1538                                                   const EmitVariableDeclarationConfig &evdConfig)
1539 {
1540     const TBasicType basicType = decl.type.getBasicType();
1541 
1542     if (decl.type.getQualifier() == EvqUniform)
1543     {
1544         // Uniforms are declared in a pre-pass, and don't need to be outputted here.
1545         return;
1546     }
1547 
1548     if (basicType == TBasicType::EbtStruct && decl.type.isStructSpecifier() &&
1549         !evdConfig.disableStructSpecifier)
1550     {
1551         // TODO(anglebug.com/42267100): in WGSL structs probably can't be declared in
1552         // function parameters or in uniform declarations or in variable declarations, or
1553         // anonymously either within other structs or within a variable declaration. Handle
1554         // these with the same AST pre-passes as other shader translators.
1555         ASSERT(!evdConfig.isParameter);
1556         emitStructDeclaration(decl.type);
1557         if (decl.symbolType != SymbolType::Empty)
1558         {
1559             mSink << " ";
1560             emitNameOf(decl);
1561         }
1562         return;
1563     }
1564 
1565     ASSERT(basicType == TBasicType::EbtStruct || decl.symbolType != SymbolType::Empty ||
1566            evdConfig.isParameter);
1567 
1568     if (evdConfig.needsVar)
1569     {
1570         // "const" and "let" probably don't need to be ever emitted because they are more for
1571         // readability, and the GLSL compiler constant folds most (all?) the consts anyway.
1572         mSink << "var";
1573         // TODO(anglebug.com/42267100): <workgroup> or <storage>?
1574         if (evdConfig.isGlobalScope)
1575         {
1576             mSink << "<private>";
1577         }
1578         mSink << " ";
1579     }
1580     else
1581     {
1582         ASSERT(!evdConfig.isGlobalScope);
1583     }
1584 
1585     if (decl.symbolType != SymbolType::Empty)
1586     {
1587         emitNameOf(decl);
1588     }
1589     mSink << " : ";
1590     emitType(decl.type);
1591 }
1592 
visitDeclaration(Visit,TIntermDeclaration * declNode)1593 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
1594 {
1595     ASSERT(declNode->getChildCount() == 1);
1596     TIntermNode &node = *declNode->getChildNode(0);
1597 
1598     EmitVariableDeclarationConfig evdConfig;
1599     evdConfig.needsVar      = true;
1600     evdConfig.isGlobalScope = mIndentLevel == 0;
1601 
1602     if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
1603     {
1604         const TVariable &var = symbolNode->variable();
1605         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
1606             mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
1607         {
1608             // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
1609             // outside of the traverser.
1610             return false;
1611         }
1612         emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
1613     }
1614     else if (TIntermBinary *initNode = node.getAsBinaryNode())
1615     {
1616         ASSERT(initNode->getOp() == TOperator::EOpInitialize);
1617         TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
1618         TIntermTyped *valueNode       = initNode->getRight()->getAsTyped();
1619         ASSERT(leftSymbolNode && valueNode);
1620 
1621         const TVariable &var = leftSymbolNode->variable();
1622         if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
1623             mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
1624         {
1625             // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
1626             // outside of the traverser.
1627             return false;
1628         }
1629 
1630         emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
1631         mSink << " = ";
1632         groupedTraverse(*valueNode);
1633     }
1634     else
1635     {
1636         UNREACHABLE();
1637     }
1638 
1639     return false;
1640 }
1641 
visitLoop(Visit,TIntermLoop * loopNode)1642 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
1643 {
1644     const TLoopType loopType = loopNode->getType();
1645 
1646     switch (loopType)
1647     {
1648         case TLoopType::ELoopFor:
1649             return emitForLoop(loopNode);
1650         case TLoopType::ELoopWhile:
1651             return emitWhileLoop(loopNode);
1652         case TLoopType::ELoopDoWhile:
1653             return emulateDoWhileLoop(loopNode);
1654     }
1655 }
1656 
emitForLoop(TIntermLoop * loopNode)1657 bool OutputWGSLTraverser::emitForLoop(TIntermLoop *loopNode)
1658 {
1659     ASSERT(loopNode->getType() == TLoopType::ELoopFor);
1660 
1661     TIntermNode *initNode  = loopNode->getInit();
1662     TIntermTyped *condNode = loopNode->getCondition();
1663     TIntermTyped *exprNode = loopNode->getExpression();
1664 
1665     mSink << "for (";
1666 
1667     if (initNode)
1668     {
1669         initNode->traverse(this);
1670     }
1671     else
1672     {
1673         mSink << " ";
1674     }
1675 
1676     mSink << "; ";
1677 
1678     if (condNode)
1679     {
1680         condNode->traverse(this);
1681     }
1682 
1683     mSink << "; ";
1684 
1685     if (exprNode)
1686     {
1687         exprNode->traverse(this);
1688     }
1689 
1690     mSink << ")\n";
1691 
1692     loopNode->getBody()->traverse(this);
1693 
1694     return false;
1695 }
1696 
emitWhileLoop(TIntermLoop * loopNode)1697 bool OutputWGSLTraverser::emitWhileLoop(TIntermLoop *loopNode)
1698 {
1699     ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
1700 
1701     TIntermNode *initNode  = loopNode->getInit();
1702     TIntermTyped *condNode = loopNode->getCondition();
1703     TIntermTyped *exprNode = loopNode->getExpression();
1704     ASSERT(condNode);
1705     ASSERT(!initNode && !exprNode);
1706 
1707     emitIndentation();
1708     mSink << "while (";
1709     condNode->traverse(this);
1710     mSink << ")\n";
1711     loopNode->getBody()->traverse(this);
1712 
1713     return false;
1714 }
1715 
emulateDoWhileLoop(TIntermLoop * loopNode)1716 bool OutputWGSLTraverser::emulateDoWhileLoop(TIntermLoop *loopNode)
1717 {
1718     ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
1719 
1720     TIntermNode *initNode  = loopNode->getInit();
1721     TIntermTyped *condNode = loopNode->getCondition();
1722     TIntermTyped *exprNode = loopNode->getExpression();
1723     ASSERT(condNode);
1724     ASSERT(!initNode && !exprNode);
1725 
1726     emitIndentation();
1727     // Write an infinite loop.
1728     mSink << "loop {\n";
1729     mIndentLevel++;
1730     loopNode->getBody()->traverse(this);
1731     mSink << "\n";
1732     emitIndentation();
1733     // At the end of the loop, break if the loop condition dos not still hold.
1734     mSink << "if (!(";
1735     condNode->traverse(this);
1736     mSink << ") { break; }\n";
1737     mIndentLevel--;
1738     emitIndentation();
1739     mSink << "}";
1740 
1741     return false;
1742 }
1743 
visitBranch(Visit,TIntermBranch * branchNode)1744 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
1745 {
1746     const TOperator flowOp = branchNode->getFlowOp();
1747     TIntermTyped *exprNode = branchNode->getExpression();
1748 
1749     emitIndentation();
1750 
1751     switch (flowOp)
1752     {
1753         case TOperator::EOpKill:
1754         {
1755             ASSERT(exprNode == nullptr);
1756             mSink << "discard";
1757         }
1758         break;
1759 
1760         case TOperator::EOpReturn:
1761         {
1762             mSink << "return";
1763             if (exprNode)
1764             {
1765                 mSink << " ";
1766                 exprNode->traverse(this);
1767             }
1768         }
1769         break;
1770 
1771         case TOperator::EOpBreak:
1772         {
1773             ASSERT(exprNode == nullptr);
1774             mSink << "break";
1775         }
1776         break;
1777 
1778         case TOperator::EOpContinue:
1779         {
1780             ASSERT(exprNode == nullptr);
1781             mSink << "continue";
1782         }
1783         break;
1784 
1785         default:
1786         {
1787             UNREACHABLE();
1788         }
1789     }
1790 
1791     return false;
1792 }
1793 
visitPreprocessorDirective(TIntermPreprocessorDirective * node)1794 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
1795 {
1796     // No preprocessor directives expected at this point.
1797     UNREACHABLE();
1798 }
1799 
emitBareTypeName(const TType & type)1800 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
1801 {
1802     WriteWgslBareTypeName(mSink, type);
1803 }
1804 
emitType(const TType & type)1805 void OutputWGSLTraverser::emitType(const TType &type)
1806 {
1807     WriteWgslType(mSink, type);
1808 }
1809 
1810 }  // namespace
1811 
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)1812 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
1813     : TCompiler(type, spec, output)
1814 {}
1815 
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)1816 bool TranslatorWGSL::translate(TIntermBlock *root,
1817                                const ShCompileOptions &compileOptions,
1818                                PerformanceDiagnostics *perfDiagnostics)
1819 {
1820     if (kOutputTreeBeforeTranslation)
1821     {
1822         OutputTree(root, getInfoSink().info);
1823         std::cout << getInfoSink().info.c_str();
1824     }
1825 
1826     RewritePipelineVarOutput rewritePipelineVarOutput(getShaderType());
1827 
1828     // WGSL's main() will need to take parameters or return values if any glsl (input/output)
1829     // builtin variables are used.
1830     if (!GenerateMainFunctionAndIOStructs(*this, *root, rewritePipelineVarOutput))
1831     {
1832         return false;
1833     }
1834 
1835     TInfoSinkBase &sink = getInfoSink().obj;
1836     // Start writing the output structs that will be referred to by the `traverser`'s output.'
1837     if (!rewritePipelineVarOutput.OutputStructs(sink))
1838     {
1839         return false;
1840     }
1841 
1842     if (!OutputUniformBlocks(this, root))
1843     {
1844         return false;
1845     }
1846 
1847     UniformBlockMetadata uniformBlockMetadata;
1848     if (!RecordUniformBlockMetadata(root, uniformBlockMetadata))
1849     {
1850         return false;
1851     }
1852 
1853     // Write the body of the WGSL including the GLSL main() function.
1854     OutputWGSLTraverser traverser(this, &rewritePipelineVarOutput, &uniformBlockMetadata);
1855     root->traverse(&traverser);
1856 
1857     // Write the actual WGSL main function, wgslMain(), which calls the GLSL main function.
1858     if (!rewritePipelineVarOutput.OutputMainFunction(sink))
1859     {
1860         return false;
1861     }
1862 
1863     if (kOutputTranslatedShader)
1864     {
1865         std::cout << sink.str();
1866     }
1867 
1868     return true;
1869 }
1870 
shouldFlattenPragmaStdglInvariantAll()1871 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
1872 {
1873     // Not neccesary for WGSL transformation.
1874     return false;
1875 }
1876 }  // namespace sh
1877