1 //
2 // Copyright 2024 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/wgsl/TranslatorWGSL.h"
8
#include <cctype>
#include <cmath>
#include <iostream>
#include <limits>
#include <variant>
11
12 #include "GLSLANG/ShaderLang.h"
13 #include "common/log_utils.h"
14 #include "compiler/translator/BaseTypes.h"
15 #include "compiler/translator/Common.h"
16 #include "compiler/translator/Diagnostics.h"
17 #include "compiler/translator/ImmutableString.h"
18 #include "compiler/translator/ImmutableStringBuilder.h"
19 #include "compiler/translator/InfoSink.h"
20 #include "compiler/translator/IntermNode.h"
21 #include "compiler/translator/OutputTree.h"
22 #include "compiler/translator/StaticType.h"
23 #include "compiler/translator/SymbolUniqueId.h"
24 #include "compiler/translator/Types.h"
25 #include "compiler/translator/tree_util/BuiltIn_autogen.h"
26 #include "compiler/translator/tree_util/FindMain.h"
27 #include "compiler/translator/tree_util/IntermNode_util.h"
28 #include "compiler/translator/tree_util/IntermTraverse.h"
29 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
30 #include "compiler/translator/wgsl/OutputUniformBlocks.h"
31 #include "compiler/translator/wgsl/RewritePipelineVariables.h"
32 #include "compiler/translator/wgsl/Utils.h"
33
34 namespace sh
35 {
36 namespace
37 {
38
// Compile-time switches, both off. Judging by their names they enable diagnostic dumps of the
// AST / translated shader; NOTE(review): confirm at their use sites.
constexpr bool kOutputTreeBeforeTranslation{false};
constexpr bool kOutputTranslatedShader{false};
41
42 struct VarDecl
43 {
44 const SymbolType symbolType = SymbolType::Empty;
45 const ImmutableString &symbolName;
46 const TType &type;
47 };
48
IsDefaultUniform(const TType & type)49 bool IsDefaultUniform(const TType &type)
50 {
51 return type.getQualifier() == EvqUniform && type.getInterfaceBlock() == nullptr &&
52 !IsOpaqueType(type.getBasicType());
53 }
54
55 // When emitting a list of statements, this determines whether a semicolon follows the statement.
RequiresSemicolonTerminator(TIntermNode & node)56 bool RequiresSemicolonTerminator(TIntermNode &node)
57 {
58 if (node.getAsBlock())
59 {
60 return false;
61 }
62 if (node.getAsLoopNode())
63 {
64 return false;
65 }
66 if (node.getAsSwitchNode())
67 {
68 return false;
69 }
70 if (node.getAsIfElseNode())
71 {
72 return false;
73 }
74 if (node.getAsFunctionDefinition())
75 {
76 return false;
77 }
78 if (node.getAsCaseNode())
79 {
80 return false;
81 }
82
83 return true;
84 }
85
86 // For pretty formatting of the resulting WGSL text.
NewlinePad(TIntermNode & node)87 bool NewlinePad(TIntermNode &node)
88 {
89 if (node.getAsFunctionDefinition())
90 {
91 return true;
92 }
93 if (TIntermDeclaration *declNode = node.getAsDeclarationNode())
94 {
95 ASSERT(declNode->getChildCount() == 1);
96 TIntermNode &childNode = *declNode->getChildNode(0);
97 if (TIntermSymbol *symbolNode = childNode.getAsSymbolNode())
98 {
99 const TVariable &var = symbolNode->variable();
100 return var.getType().isStructSpecifier();
101 }
102 return false;
103 }
104 return false;
105 }
106
107 // A traverser that generates WGSL as it walks the AST.
108 class OutputWGSLTraverser : public TIntermTraverser
109 {
110 public:
111 OutputWGSLTraverser(TCompiler *compiler,
112 RewritePipelineVarOutput *rewritePipelineVarOutput,
113 UniformBlockMetadata *uniformBlockMetadata);
114 ~OutputWGSLTraverser() override;
115
116 protected:
117 void visitSymbol(TIntermSymbol *node) override;
118 void visitConstantUnion(TIntermConstantUnion *node) override;
119 bool visitSwizzle(Visit visit, TIntermSwizzle *node) override;
120 bool visitBinary(Visit visit, TIntermBinary *node) override;
121 bool visitUnary(Visit visit, TIntermUnary *node) override;
122 bool visitTernary(Visit visit, TIntermTernary *node) override;
123 bool visitIfElse(Visit visit, TIntermIfElse *node) override;
124 bool visitSwitch(Visit visit, TIntermSwitch *node) override;
125 bool visitCase(Visit visit, TIntermCase *node) override;
126 void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
127 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
128 bool visitAggregate(Visit visit, TIntermAggregate *node) override;
129 bool visitBlock(Visit visit, TIntermBlock *node) override;
130 bool visitGlobalQualifierDeclaration(Visit visit,
131 TIntermGlobalQualifierDeclaration *node) override;
132 bool visitDeclaration(Visit visit, TIntermDeclaration *node) override;
133 bool visitLoop(Visit visit, TIntermLoop *node) override;
134 bool visitBranch(Visit visit, TIntermBranch *node) override;
135 void visitPreprocessorDirective(TIntermPreprocessorDirective *node) override;
136
137 private:
138 struct EmitVariableDeclarationConfig
139 {
140 bool isParameter = false;
141 bool disableStructSpecifier = false;
142 bool needsVar = false;
143 bool isGlobalScope = false;
144 };
145
146 void groupedTraverse(TIntermNode &node);
147 void emitNameOf(const VarDecl &decl);
148 void emitBareTypeName(const TType &type);
149 void emitType(const TType &type);
150 void emitSingleConstant(const TConstantUnion *const constUnion);
151 const TConstantUnion *emitConstantUnionArray(const TConstantUnion *const constUnion,
152 const size_t size);
153 const TConstantUnion *emitConstantUnion(const TType &type,
154 const TConstantUnion *constUnionBegin);
155 const TField &getDirectField(const TIntermTyped &fieldsNode, TIntermTyped &indexNode);
156 void emitIndentation();
157 void emitOpenBrace();
158 void emitCloseBrace();
159 bool emitBlock(TSpan<TIntermNode *> nodes);
160 void emitFunctionSignature(const TFunction &func);
161 void emitFunctionReturn(const TFunction &func);
162 void emitFunctionParameter(const TFunction &func, const TVariable ¶m);
163 void emitStructDeclaration(const TType &type);
164 void emitVariableDeclaration(const VarDecl &decl,
165 const EmitVariableDeclarationConfig &evdConfig);
166 void emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode);
167
168 bool emitForLoop(TIntermLoop *);
169 bool emitWhileLoop(TIntermLoop *);
170 bool emulateDoWhileLoop(TIntermLoop *);
171
172 TInfoSinkBase &mSink;
173 RewritePipelineVarOutput *mRewritePipelineVarOutput;
174 UniformBlockMetadata *mUniformBlockMetadata;
175
176 int mIndentLevel = -1;
177 int mLastIndentationPos = -1;
178 };
179
OutputWGSLTraverser(TCompiler * compiler,RewritePipelineVarOutput * rewritePipelineVarOutput,UniformBlockMetadata * uniformBlockMetadata)180 OutputWGSLTraverser::OutputWGSLTraverser(TCompiler *compiler,
181 RewritePipelineVarOutput *rewritePipelineVarOutput,
182 UniformBlockMetadata *uniformBlockMetadata)
183 : TIntermTraverser(true, false, false),
184 mSink(compiler->getInfoSink().obj),
185 mRewritePipelineVarOutput(rewritePipelineVarOutput),
186 mUniformBlockMetadata(uniformBlockMetadata)
187 {}
188
189 OutputWGSLTraverser::~OutputWGSLTraverser() = default;
190
groupedTraverse(TIntermNode & node)191 void OutputWGSLTraverser::groupedTraverse(TIntermNode &node)
192 {
193 // TODO(anglebug.com/42267100): to make generated code more readable, do not always
194 // emit parentheses like WGSL is some Lisp dialect.
195 const bool emitParens = true;
196
197 if (emitParens)
198 {
199 mSink << "(";
200 }
201
202 node.traverse(this);
203
204 if (emitParens)
205 {
206 mSink << ")";
207 }
208 }
209
emitNameOf(const VarDecl & decl)210 void OutputWGSLTraverser::emitNameOf(const VarDecl &decl)
211 {
212 WriteNameOf(mSink, decl.symbolType, decl.symbolName);
213 }
214
emitIndentation()215 void OutputWGSLTraverser::emitIndentation()
216 {
217 ASSERT(mIndentLevel >= 0);
218
219 if (mLastIndentationPos == mSink.size())
220 {
221 return; // Line is already indented.
222 }
223
224 for (int i = 0; i < mIndentLevel; ++i)
225 {
226 mSink << " ";
227 }
228
229 mLastIndentationPos = mSink.size();
230 }
231
emitOpenBrace()232 void OutputWGSLTraverser::emitOpenBrace()
233 {
234 ASSERT(mIndentLevel >= 0);
235
236 emitIndentation();
237 mSink << "{\n";
238 ++mIndentLevel;
239 }
240
emitCloseBrace()241 void OutputWGSLTraverser::emitCloseBrace()
242 {
243 ASSERT(mIndentLevel >= 1);
244
245 --mIndentLevel;
246 emitIndentation();
247 mSink << "}";
248 }
249
visitSymbol(TIntermSymbol * symbolNode)250 void OutputWGSLTraverser::visitSymbol(TIntermSymbol *symbolNode)
251 {
252
253 const TVariable &var = symbolNode->variable();
254 const TType &type = var.getType();
255 ASSERT(var.symbolType() != SymbolType::Empty);
256
257 if (type.getBasicType() == TBasicType::EbtVoid)
258 {
259 UNREACHABLE();
260 }
261 else
262 {
263 // Accesses of pipeline variables should be rewritten as struct accesses.
264 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()))
265 {
266 mSink << kBuiltinInputStructName << "." << var.name();
267 }
268 else if (mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
269 {
270 mSink << kBuiltinOutputStructName << "." << var.name();
271 }
272 // Accesses of basic uniforms need to be converted to struct accesses.
273 else if (IsDefaultUniform(type))
274 {
275 mSink << kDefaultUniformBlockVarName << "." << var.name();
276 }
277 else
278 {
279 WriteNameOf(mSink, var);
280 }
281
282 if (var.symbolType() == SymbolType::BuiltIn)
283 {
284 ASSERT(mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
285 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()) ||
286 var.uniqueId() == BuiltInId::gl_DepthRange);
287 // TODO(anglebug.com/42267100): support gl_DepthRange.
288 // Match the name of the struct field in `mRewritePipelineVarOutput`.
289 mSink << "_";
290 }
291 }
292 }
293
emitSingleConstant(const TConstantUnion * const constUnion)294 void OutputWGSLTraverser::emitSingleConstant(const TConstantUnion *const constUnion)
295 {
296 switch (constUnion->getType())
297 {
298 case TBasicType::EbtBool:
299 {
300 mSink << (constUnion->getBConst() ? "true" : "false");
301 }
302 break;
303
304 case TBasicType::EbtFloat:
305 {
306 float value = constUnion->getFConst();
307 if (std::isnan(value))
308 {
309 UNIMPLEMENTED();
310 // TODO(anglebug.com/42267100): this is not a valid constant in WGPU.
311 // You can't even do something like bitcast<f32>(0xffffffffu).
312 // The WGSL compiler still complains. I think this is because
313 // WGSL supports implementations compiling with -ffastmath and
314 // therefore nans and infinities are assumed to not exist.
315 // See also https://github.com/gpuweb/gpuweb/issues/3749.
316 mSink << "NAN_INVALID";
317 }
318 else if (std::isinf(value))
319 {
320 UNIMPLEMENTED();
321 // see above.
322 mSink << "INFINITY_INVALID";
323 }
324 else
325 {
326 mSink << value << "f";
327 }
328 }
329 break;
330
331 case TBasicType::EbtInt:
332 {
333 mSink << constUnion->getIConst() << "i";
334 }
335 break;
336
337 case TBasicType::EbtUInt:
338 {
339 mSink << constUnion->getUConst() << "u";
340 }
341 break;
342
343 default:
344 {
345 UNIMPLEMENTED();
346 }
347 }
348 }
349
// Emits `size` consecutive scalar constants separated by ", " and returns a pointer just past
// the last one emitted (equal to `constUnion` when size is 0).
const TConstantUnion *OutputWGSLTraverser::emitConstantUnionArray(
    const TConstantUnion *const constUnion,
    const size_t size)
{
    const TConstantUnion *const first = constUnion;
    const TConstantUnion *const last  = constUnion + size;

    const TConstantUnion *cursor = first;
    for (; cursor != last; ++cursor)
    {
        if (cursor != first)
        {
            mSink << ", ";
        }
        emitSingleConstant(cursor);
    }
    return cursor;
}
366
// Emits a constant value of `type` whose scalar components start at `constUnionBegin`, and
// returns a pointer to the first component it did not consume. Structs become constructor calls
// whose arguments are emitted recursively field by field; vectors, matrices and arrays (anything
// with more than one component) become `type(c0, c1, ...)`; single scalars are emitted bare.
const TConstantUnion *OutputWGSLTraverser::emitConstantUnion(const TType &type,
                                                             const TConstantUnion *constUnionBegin)
{
    const TConstantUnion *cursor = constUnionBegin;

    if (const TStructure *structure = type.getStruct())
    {
        // Structs are constructed with parentheses in WGSL, and both GLSL and WGSL require one
        // constructor argument per struct field.
        emitType(type);
        mSink << "(";
        bool firstField = true;
        for (const TField *field : structure->fields())
        {
            if (!firstField)
            {
                mSink << ", ";
            }
            firstField = false;
            cursor     = emitConstantUnion(*field->type(), cursor);
        }
        mSink << ")";
        return cursor;
    }

    // Multi-component values (vectors, matrices, arrays) need an explicit constructor with
    // parentheses.
    const size_t componentCount = type.getObjectSize();
    const bool needsConstructor = componentCount > 1;
    if (needsConstructor)
    {
        emitType(type);
        mSink << "(";
    }
    cursor = emitConstantUnionArray(cursor, componentCount);
    if (needsConstructor)
    {
        mSink << ")";
    }
    return cursor;
}
410
visitConstantUnion(TIntermConstantUnion * constValueNode)411 void OutputWGSLTraverser::visitConstantUnion(TIntermConstantUnion *constValueNode)
412 {
413 emitConstantUnion(constValueNode->getType(), constValueNode->getConstantValue());
414 }
415
visitSwizzle(Visit,TIntermSwizzle * swizzleNode)416 bool OutputWGSLTraverser::visitSwizzle(Visit, TIntermSwizzle *swizzleNode)
417 {
418 groupedTraverse(*swizzleNode->getOperand());
419 mSink << ".";
420 swizzleNode->writeOffsetsAsXYZW(&mSink);
421
422 return false;
423 }
424
GetOperatorString(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1,const TType * argType2)425 const char *GetOperatorString(TOperator op,
426 const TType &resultType,
427 const TType *argType0,
428 const TType *argType1,
429 const TType *argType2)
430 {
431 switch (op)
432 {
433 case TOperator::EOpComma:
434 // WGSL does not have a comma operator or any other way to implement "statement list as
435 // an expression", so nested expressions will have to be pulled out into statements.
436 UNIMPLEMENTED();
437 return "TODO_operator";
438 case TOperator::EOpAssign:
439 return "=";
440 case TOperator::EOpInitialize:
441 return "=";
442 // Compound assignments now exist: https://www.w3.org/TR/WGSL/#compound-assignment-sec
443 case TOperator::EOpAddAssign:
444 return "+=";
445 case TOperator::EOpSubAssign:
446 return "-=";
447 case TOperator::EOpMulAssign:
448 return "*=";
449 case TOperator::EOpDivAssign:
450 return "/=";
451 case TOperator::EOpIModAssign:
452 return "%=";
453 case TOperator::EOpBitShiftLeftAssign:
454 return "<<=";
455 case TOperator::EOpBitShiftRightAssign:
456 return ">>=";
457 case TOperator::EOpBitwiseAndAssign:
458 return "&=";
459 case TOperator::EOpBitwiseXorAssign:
460 return "^=";
461 case TOperator::EOpBitwiseOrAssign:
462 return "|=";
463 case TOperator::EOpAdd:
464 return "+";
465 case TOperator::EOpSub:
466 return "-";
467 case TOperator::EOpMul:
468 return "*";
469 case TOperator::EOpDiv:
470 return "/";
471 // TODO(anglebug.com/42267100): Works different from GLSL for negative numbers.
472 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=not%20WGSL%3B%20etc.-,Inconsistent%20mod/%25%20operator,-At%20first%20glance
473 // GLSL does `x - y * floor(x/y)`, WGSL does x - y * trunc(x/y).
474 case TOperator::EOpIMod:
475 case TOperator::EOpMod:
476 return "%";
477 case TOperator::EOpBitShiftLeft:
478 return "<<";
479 case TOperator::EOpBitShiftRight:
480 return ">>";
481 case TOperator::EOpBitwiseAnd:
482 return "&";
483 case TOperator::EOpBitwiseXor:
484 return "^";
485 case TOperator::EOpBitwiseOr:
486 return "|";
487 case TOperator::EOpLessThan:
488 return "<";
489 case TOperator::EOpGreaterThan:
490 return ">";
491 case TOperator::EOpLessThanEqual:
492 return "<=";
493 case TOperator::EOpGreaterThanEqual:
494 return ">=";
495 // Component-wise comparisons are done with regular infix operators in WGSL:
496 // https://www.w3.org/TR/WGSL/#comparison-expr
497 case TOperator::EOpLessThanComponentWise:
498 return "<";
499 case TOperator::EOpLessThanEqualComponentWise:
500 return "<=";
501 case TOperator::EOpGreaterThanEqualComponentWise:
502 return ">=";
503 case TOperator::EOpGreaterThanComponentWise:
504 return ">";
505 case TOperator::EOpLogicalOr:
506 return "||";
507 // Logical XOR is only applied to boolean expressions so it's the same as "not equals".
508 // Neither short-circuits.
509 case TOperator::EOpLogicalXor:
510 return "!=";
511 case TOperator::EOpLogicalAnd:
512 return "&&";
513 case TOperator::EOpNegative:
514 return "-";
515 case TOperator::EOpPositive:
516 if (argType0->isMatrix())
517 {
518 return "";
519 }
520 return "+";
521 case TOperator::EOpLogicalNot:
522 return "!";
523 // Component-wise not done with normal prefix unary operator in WGSL:
524 // https://www.w3.org/TR/WGSL/#logical-expr
525 case TOperator::EOpNotComponentWise:
526 return "!";
527 case TOperator::EOpBitwiseNot:
528 return "~";
529 // TODO(anglebug.com/42267100): increment operations cannot be used as expressions in WGSL.
530 case TOperator::EOpPostIncrement:
531 return "++";
532 case TOperator::EOpPostDecrement:
533 return "--";
534 case TOperator::EOpPreIncrement:
535 case TOperator::EOpPreDecrement:
536 // TODO(anglebug.com/42267100): pre increments and decrements do not exist in WGSL.
537 UNIMPLEMENTED();
538 return "TODO_operator";
539 case TOperator::EOpVectorTimesScalarAssign:
540 return "*=";
541 case TOperator::EOpVectorTimesMatrixAssign:
542 return "*=";
543 case TOperator::EOpMatrixTimesScalarAssign:
544 return "*=";
545 case TOperator::EOpMatrixTimesMatrixAssign:
546 return "*=";
547 case TOperator::EOpVectorTimesScalar:
548 return "*";
549 case TOperator::EOpVectorTimesMatrix:
550 return "*";
551 case TOperator::EOpMatrixTimesVector:
552 return "*";
553 case TOperator::EOpMatrixTimesScalar:
554 return "*";
555 case TOperator::EOpMatrixTimesMatrix:
556 return "*";
557 case TOperator::EOpEqualComponentWise:
558 return "==";
559 case TOperator::EOpNotEqualComponentWise:
560 return "!=";
561
562 // TODO(anglebug.com/42267100): structs, matrices, and arrays are not comparable with WGSL's
563 // == or !=. Comparing vectors results in a component-wise comparison returning a boolean
564 // vector, which is different from GLSL (which use equal(vec, vec) for component-wise
565 // comparison)
566 case TOperator::EOpEqual:
567 if ((argType0->isVector() && argType1->isVector()) ||
568 (argType0->getStruct() && argType1->getStruct()) ||
569 (argType0->isArray() && argType1->isArray()) ||
570 (argType0->isMatrix() && argType1->isMatrix()))
571
572 {
573 UNIMPLEMENTED();
574 return "TODO_operator";
575 }
576
577 return "==";
578
579 case TOperator::EOpNotEqual:
580 if ((argType0->isVector() && argType1->isVector()) ||
581 (argType0->getStruct() && argType1->getStruct()) ||
582 (argType0->isArray() && argType1->isArray()) ||
583 (argType0->isMatrix() && argType1->isMatrix()))
584 {
585 UNIMPLEMENTED();
586 return "TODO_operator";
587 }
588 return "!=";
589
590 case TOperator::EOpKill:
591 case TOperator::EOpReturn:
592 case TOperator::EOpBreak:
593 case TOperator::EOpContinue:
594 // These should all be emitted in visitBranch().
595 UNREACHABLE();
596 return "UNREACHABLE_operator";
597 case TOperator::EOpRadians:
598 return "radians";
599 case TOperator::EOpDegrees:
600 return "degrees";
601 case TOperator::EOpAtan:
602 return argType1 == nullptr ? "atan" : "atan2";
603 case TOperator::EOpRefract:
604 return argType0->isVector() ? "refract" : "TODO_operator";
605 case TOperator::EOpDistance:
606 return "distance";
607 case TOperator::EOpLength:
608 return "length";
609 case TOperator::EOpDot:
610 return argType0->isVector() ? "dot" : "*";
611 case TOperator::EOpNormalize:
612 return argType0->isVector() ? "normalize" : "sign";
613 case TOperator::EOpFaceforward:
614 return argType0->isVector() ? "faceForward" : "TODO_Operator";
615 case TOperator::EOpReflect:
616 return argType0->isVector() ? "reflect" : "TODO_Operator";
617 case TOperator::EOpMatrixCompMult:
618 return "TODO_Operator";
619 case TOperator::EOpOuterProduct:
620 return "TODO_Operator";
621 case TOperator::EOpSign:
622 return "sign";
623
624 case TOperator::EOpAbs:
625 return "abs";
626 case TOperator::EOpAll:
627 return "all";
628 case TOperator::EOpAny:
629 return "any";
630 case TOperator::EOpSin:
631 return "sin";
632 case TOperator::EOpCos:
633 return "cos";
634 case TOperator::EOpTan:
635 return "tan";
636 case TOperator::EOpAsin:
637 return "asin";
638 case TOperator::EOpAcos:
639 return "acos";
640 case TOperator::EOpSinh:
641 return "sinh";
642 case TOperator::EOpCosh:
643 return "cosh";
644 case TOperator::EOpTanh:
645 return "tanh";
646 case TOperator::EOpAsinh:
647 return "asinh";
648 case TOperator::EOpAcosh:
649 return "acosh";
650 case TOperator::EOpAtanh:
651 return "atanh";
652 case TOperator::EOpFma:
653 return "fma";
654 // TODO(anglebug.com/42267100): Won't accept pow(vec<f32>, f32).
655 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=Similarly%20pow(vec3%3Cf32%3E%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL
656 case TOperator::EOpPow:
657 return "pow"; // GLSL's pow excludes negative x
658 case TOperator::EOpExp:
659 return "exp";
660 case TOperator::EOpExp2:
661 return "exp2";
662 case TOperator::EOpLog:
663 return "log";
664 case TOperator::EOpLog2:
665 return "log2";
666 case TOperator::EOpSqrt:
667 return "sqrt";
668 case TOperator::EOpFloor:
669 return "floor";
670 case TOperator::EOpTrunc:
671 return "trunc";
672 case TOperator::EOpCeil:
673 return "ceil";
674 case TOperator::EOpFract:
675 return "fract";
676 case TOperator::EOpMin:
677 return "min";
678 case TOperator::EOpMax:
679 return "max";
680 case TOperator::EOpRound:
681 return "round"; // TODO(anglebug.com/42267100): this is wrong and must round away from
682 // zero if there is a tie. This always rounds to the even number.
683 case TOperator::EOpRoundEven:
684 return "round";
685 // TODO(anglebug.com/42267100):
686 // https://github.com/gpuweb/gpuweb/discussions/2204#:~:text=clamp(vec2%3Cf32%3E%2C%20f32%2C%20f32)%20works%20in%20GLSL%20but%20not%20WGSL%3B%20etc.
687 // Need to expand clamp(vec<f32>, low : f32, high : f32) ->
688 // clamp(vec<f32>, vec<f32>(low), vec<f32>(high))
689 case TOperator::EOpClamp:
690 return "clamp";
691 case TOperator::EOpSaturate:
692 return "saturate";
693 case TOperator::EOpMix:
694 if (!argType1->isScalar() && argType2 && argType2->getBasicType() == EbtBool)
695 {
696 return "TODO_Operator";
697 }
698 return "mix";
699 case TOperator::EOpStep:
700 return "step";
701 case TOperator::EOpSmoothstep:
702 return "smoothstep";
703 case TOperator::EOpModf:
704 UNIMPLEMENTED(); // TODO(anglebug.com/42267100): in WGSL this returns a struct, GLSL it
705 // uses a return value and an outparam
706 return "modf";
707 case TOperator::EOpIsnan:
708 case TOperator::EOpIsinf:
709 UNIMPLEMENTED(); // TODO(anglebug.com/42267100): WGSL does not allow NaNs or infinity.
710 // What to do about shaders that require this?
711 // Implementations are allowed to assume overflow, infinities, and NaNs are not present
712 // at runtime, however. https://www.w3.org/TR/WGSL/#floating-point-evaluation
713 return "TODO_Operator";
714 case TOperator::EOpLdexp:
715 // TODO(anglebug.com/42267100): won't accept first arg vector, second arg scalar
716 return "ldexp";
717 case TOperator::EOpFrexp:
718 return "frexp"; // TODO(anglebug.com/42267100): returns a struct
719 case TOperator::EOpInversesqrt:
720 return "inverseSqrt";
721 case TOperator::EOpCross:
722 return "cross";
723 // TODO(anglebug.com/42267100): are these the same? dpdxCoarse() vs dpdxFine()?
724 case TOperator::EOpDFdx:
725 return "dpdx";
726 case TOperator::EOpDFdy:
727 return "dpdy";
728 case TOperator::EOpFwidth:
729 return "fwidth";
730 case TOperator::EOpTranspose:
731 return "transpose";
732 case TOperator::EOpDeterminant:
733 return "determinant";
734
735 case TOperator::EOpInverse:
736 return "TODO_Operator"; // No builtin invert().
737 // https://github.com/gpuweb/gpuweb/issues/4115
738
739 // TODO(anglebug.com/42267100): these interpolateAt*() are not builtin
740 case TOperator::EOpInterpolateAtCentroid:
741 return "TODO_Operator";
742 case TOperator::EOpInterpolateAtSample:
743 return "TODO_Operator";
744 case TOperator::EOpInterpolateAtOffset:
745 return "TODO_Operator";
746 case TOperator::EOpInterpolateAtCenter:
747 return "TODO_Operator";
748
749 case TOperator::EOpFloatBitsToInt:
750 case TOperator::EOpFloatBitsToUint:
751 case TOperator::EOpIntBitsToFloat:
752 case TOperator::EOpUintBitsToFloat:
753 {
754 #define BITCAST_SCALAR() \
755 do \
756 switch (resultType.getBasicType()) \
757 { \
758 case TBasicType::EbtInt: \
759 return "bitcast<i32>"; \
760 case TBasicType::EbtUInt: \
761 return "bitcast<u32>"; \
762 case TBasicType::EbtFloat: \
763 return "bitcast<f32>"; \
764 default: \
765 UNIMPLEMENTED(); \
766 return "TOperator_TODO"; \
767 } \
768 while (false)
769
770 #define BITCAST_VECTOR(vecSize) \
771 do \
772 switch (resultType.getBasicType()) \
773 { \
774 case TBasicType::EbtInt: \
775 return "bitcast<vec" vecSize "<i32>>"; \
776 case TBasicType::EbtUInt: \
777 return "bitcast<vec" vecSize "<u32>>"; \
778 case TBasicType::EbtFloat: \
779 return "bitcast<vec" vecSize "<f32>>"; \
780 default: \
781 UNIMPLEMENTED(); \
782 return "TOperator_TODO"; \
783 } \
784 while (false)
785
786 if (resultType.isScalar())
787 {
788 BITCAST_SCALAR();
789 }
790 else if (resultType.isVector())
791 {
792 switch (resultType.getNominalSize())
793 {
794 case 2:
795 BITCAST_VECTOR("2");
796 case 3:
797 BITCAST_VECTOR("3");
798 case 4:
799 BITCAST_VECTOR("4");
800 default:
801 UNREACHABLE();
802 return nullptr;
803 }
804 }
805 else
806 {
807 UNIMPLEMENTED();
808 return "TOperator_TODO";
809 }
810
811 #undef BITCAST_SCALAR
812 #undef BITCAST_VECTOR
813 }
814
815 case TOperator::EOpPackUnorm2x16:
816 return "pack2x16unorm";
817 case TOperator::EOpPackSnorm2x16:
818 return "pack2x16snorm";
819
820 case TOperator::EOpPackUnorm4x8:
821 return "pack4x8unorm";
822 case TOperator::EOpPackSnorm4x8:
823 return "pack4x8snorm";
824
825 case TOperator::EOpUnpackUnorm2x16:
826 return "unpack2x16unorm";
827 case TOperator::EOpUnpackSnorm2x16:
828 return "unpack2x16snorm";
829
830 case TOperator::EOpUnpackUnorm4x8:
831 return "unpack4x8unorm";
832 case TOperator::EOpUnpackSnorm4x8:
833 return "unpack4x8snorm";
834
835 case TOperator::EOpPackHalf2x16:
836 return "pack2x16float";
837 case TOperator::EOpUnpackHalf2x16:
838 return "unpack2x16float";
839
840 case TOperator::EOpBarrier:
841 UNREACHABLE();
842 return "TOperator_TODO";
843 case TOperator::EOpMemoryBarrier:
844 // TODO(anglebug.com/42267100): does this exist in WGPU? Device-scoped memory barrier?
845 // Maybe storageBarrier()?
846 UNREACHABLE();
847 return "TOperator_TODO";
848 case TOperator::EOpGroupMemoryBarrier:
849 return "workgroupBarrier";
850 case TOperator::EOpMemoryBarrierAtomicCounter:
851 case TOperator::EOpMemoryBarrierBuffer:
852 case TOperator::EOpMemoryBarrierShared:
853 UNREACHABLE();
854 return "TOperator_TODO";
855 case TOperator::EOpAtomicAdd:
856 return "atomicAdd";
857 case TOperator::EOpAtomicMin:
858 return "atomicMin";
859 case TOperator::EOpAtomicMax:
860 return "atomicMax";
861 case TOperator::EOpAtomicAnd:
862 return "atomicAnd";
863 case TOperator::EOpAtomicOr:
864 return "atomicOr";
865 case TOperator::EOpAtomicXor:
866 return "atomicXor";
867 case TOperator::EOpAtomicExchange:
868 return "atomicExchange";
869 case TOperator::EOpAtomicCompSwap:
870 return "atomicCompareExchangeWeak"; // TODO(anglebug.com/42267100): returns a struct.
871 case TOperator::EOpBitfieldExtract:
872 case TOperator::EOpBitfieldInsert:
873 case TOperator::EOpBitfieldReverse:
874 case TOperator::EOpBitCount:
875 case TOperator::EOpFindLSB:
876 case TOperator::EOpFindMSB:
877 case TOperator::EOpUaddCarry:
878 case TOperator::EOpUsubBorrow:
879 case TOperator::EOpUmulExtended:
880 case TOperator::EOpImulExtended:
881 case TOperator::EOpEmitVertex:
882 case TOperator::EOpEndPrimitive:
883 case TOperator::EOpArrayLength:
884 UNIMPLEMENTED();
885 return "TOperator_TODO";
886
887 case TOperator::EOpNull:
888 case TOperator::EOpConstruct:
889 case TOperator::EOpCallFunctionInAST:
890 case TOperator::EOpCallInternalRawFunction:
891 case TOperator::EOpIndexDirect:
892 case TOperator::EOpIndexIndirect:
893 case TOperator::EOpIndexDirectStruct:
894 case TOperator::EOpIndexDirectInterfaceBlock:
895 UNREACHABLE();
896 return nullptr;
897 default:
898 // Any other built-in function.
899 return nullptr;
900 }
901 }
902
IsSymbolicOperator(TOperator op,const TType & resultType,const TType * argType0,const TType * argType1)903 bool IsSymbolicOperator(TOperator op,
904 const TType &resultType,
905 const TType *argType0,
906 const TType *argType1)
907 {
908 const char *operatorString = GetOperatorString(op, resultType, argType0, argType1, nullptr);
909 if (operatorString == nullptr)
910 {
911 return false;
912 }
913 return !std::isalnum(operatorString[0]);
914 }
915
getDirectField(const TIntermTyped & fieldsNode,TIntermTyped & indexNode)916 const TField &OutputWGSLTraverser::getDirectField(const TIntermTyped &fieldsNode,
917 TIntermTyped &indexNode)
918 {
919 const TType &fieldsType = fieldsNode.getType();
920
921 const TFieldListCollection *fieldListCollection = fieldsType.getStruct();
922 if (fieldListCollection == nullptr)
923 {
924 fieldListCollection = fieldsType.getInterfaceBlock();
925 }
926 ASSERT(fieldListCollection);
927
928 const TIntermConstantUnion *indexNodeAsConstantUnion = indexNode.getAsConstantUnion();
929 ASSERT(indexNodeAsConstantUnion);
930 const TConstantUnion &index = *indexNodeAsConstantUnion->getConstantValue();
931
932 ASSERT(index.getType() == TBasicType::EbtInt);
933
934 const TFieldList &fieldList = fieldListCollection->fields();
935 const int indexVal = index.getIConst();
936 const TField &field = *fieldList[indexVal];
937
938 return field;
939 }
940
emitArrayIndex(TIntermTyped & leftNode,TIntermTyped & rightNode)941 void OutputWGSLTraverser::emitArrayIndex(TIntermTyped &leftNode, TIntermTyped &rightNode)
942 {
943
944 {
945 TType leftType = leftNode.getType();
946 groupedTraverse(leftNode);
947 mSink << "[";
948 const TConstantUnion *constIndex = rightNode.getConstantValue();
949 // If the array index is a constant that we can statically verify is within array
950 // bounds, just emit that constant.
951 if (!leftType.isUnsizedArray() && constIndex != nullptr &&
952 constIndex->getType() == EbtInt && constIndex->getIConst() >= 0 &&
953 constIndex->getIConst() < static_cast<int>(leftType.isArray()
954 ? leftType.getOutermostArraySize()
955 : leftType.getNominalSize()))
956 {
957 emitSingleConstant(constIndex);
958 }
959 else
960 {
961 // If the array index is not a constant within the bounds of the array, clamp the
962 // index.
963 mSink << "clamp(";
964 groupedTraverse(rightNode);
965 mSink << ", 0, ";
966 // Now find the array size and clamp it.
967 if (leftType.isUnsizedArray())
968 {
969 // TODO(anglebug.com/42267100): This is a bug to traverse the `leftNode` a
970 // second time if `leftNode` has side effects (and could also have performance
971 // implications). This should be stored in a temporary variable. This might also
972 // be a bug in the MSL shader compiler.
973 mSink << "arrayLength(&";
974 groupedTraverse(leftNode);
975 mSink << ")";
976 }
977 else
978 {
979 uint32_t maxSize;
980 if (leftType.isArray())
981 {
982 maxSize = leftType.getOutermostArraySize() - 1;
983 }
984 else
985 {
986 maxSize = leftType.getNominalSize() - 1;
987 }
988 mSink << maxSize;
989 }
990 // End the clamp() function.
991 mSink << ")";
992 }
993 // End the array index operation.
994 mSink << "]";
995 }
996 }
997
visitBinary(Visit,TIntermBinary * binaryNode)998 bool OutputWGSLTraverser::visitBinary(Visit, TIntermBinary *binaryNode)
999 {
1000 const TOperator op = binaryNode->getOp();
1001 TIntermTyped &leftNode = *binaryNode->getLeft();
1002 TIntermTyped &rightNode = *binaryNode->getRight();
1003
1004 switch (op)
1005 {
1006 case TOperator::EOpIndexDirectStruct:
1007 case TOperator::EOpIndexDirectInterfaceBlock:
1008 groupedTraverse(leftNode);
1009 mSink << ".";
1010 WriteNameOf(mSink, getDirectField(leftNode, rightNode));
1011 break;
1012
1013 case TOperator::EOpIndexDirect:
1014 case TOperator::EOpIndexIndirect:
1015 emitArrayIndex(leftNode, rightNode);
1016 break;
1017
1018 default:
1019 {
1020 const TType &resultType = binaryNode->getType();
1021 const TType &leftType = leftNode.getType();
1022 const TType &rightType = rightNode.getType();
1023
1024 // x * y, x ^ y, etc.
1025 if (IsSymbolicOperator(op, resultType, &leftType, &rightType))
1026 {
1027 groupedTraverse(leftNode);
1028 if (op != TOperator::EOpComma)
1029 {
1030 mSink << " ";
1031 }
1032 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << " ";
1033 groupedTraverse(rightNode);
1034 }
1035 // E.g. builtin function calls
1036 else
1037 {
1038 mSink << GetOperatorString(op, resultType, &leftType, &rightType, nullptr) << "(";
1039 leftNode.traverse(this);
1040 mSink << ", ";
1041 rightNode.traverse(this);
1042 mSink << ")";
1043 }
1044 }
1045 }
1046
1047 return false;
1048 }
1049
IsPostfix(TOperator op)1050 bool IsPostfix(TOperator op)
1051 {
1052 switch (op)
1053 {
1054 case TOperator::EOpPostIncrement:
1055 case TOperator::EOpPostDecrement:
1056 return true;
1057
1058 default:
1059 return false;
1060 }
1061 }
1062
visitUnary(Visit,TIntermUnary * unaryNode)1063 bool OutputWGSLTraverser::visitUnary(Visit, TIntermUnary *unaryNode)
1064 {
1065 const TOperator op = unaryNode->getOp();
1066 const TType &resultType = unaryNode->getType();
1067
1068 TIntermTyped &arg = *unaryNode->getOperand();
1069 const TType &argType = arg.getType();
1070
1071 const char *name = GetOperatorString(op, resultType, &argType, nullptr, nullptr);
1072
1073 // Examples: -x, ~x, ~x
1074 if (IsSymbolicOperator(op, resultType, &argType, nullptr))
1075 {
1076 const bool postfix = IsPostfix(op);
1077 if (!postfix)
1078 {
1079 mSink << name;
1080 }
1081 groupedTraverse(arg);
1082 if (postfix)
1083 {
1084 mSink << name;
1085 }
1086 }
1087 else
1088 {
1089 mSink << name << "(";
1090 arg.traverse(this);
1091 mSink << ")";
1092 }
1093
1094 return false;
1095 }
1096
visitTernary(Visit,TIntermTernary * conditionalNode)1097 bool OutputWGSLTraverser::visitTernary(Visit, TIntermTernary *conditionalNode)
1098 {
1099 // WGSL does not have a ternary. https://github.com/gpuweb/gpuweb/issues/3747
1100 // The select() builtin is not short circuiting. Maybe we can get if () {} else {} as an
1101 // expression, which would also solve the comma operator problem.
1102 // TODO(anglebug.com/42267100): as mentioned above this is not correct if the operands have side
1103 // effects. Even if they don't have side effects it could have performance implications.
1104 mSink << "select(";
1105 groupedTraverse(*conditionalNode->getTrueExpression());
1106 mSink << ", ";
1107 groupedTraverse(*conditionalNode->getFalseExpression());
1108 mSink << ", ";
1109 groupedTraverse(*conditionalNode->getCondition());
1110 mSink << ")";
1111
1112 return false;
1113 }
1114
visitIfElse(Visit,TIntermIfElse * ifThenElseNode)1115 bool OutputWGSLTraverser::visitIfElse(Visit, TIntermIfElse *ifThenElseNode)
1116 {
1117 TIntermTyped &condNode = *ifThenElseNode->getCondition();
1118 TIntermBlock *thenNode = ifThenElseNode->getTrueBlock();
1119 TIntermBlock *elseNode = ifThenElseNode->getFalseBlock();
1120
1121 mSink << "if (";
1122 condNode.traverse(this);
1123 mSink << ")";
1124
1125 if (thenNode)
1126 {
1127 mSink << "\n";
1128 thenNode->traverse(this);
1129 }
1130 else
1131 {
1132 mSink << " {}";
1133 }
1134
1135 if (elseNode)
1136 {
1137 mSink << "\n";
1138 emitIndentation();
1139 mSink << "else\n";
1140 elseNode->traverse(this);
1141 }
1142
1143 return false;
1144 }
1145
visitSwitch(Visit,TIntermSwitch * switchNode)1146 bool OutputWGSLTraverser::visitSwitch(Visit, TIntermSwitch *switchNode)
1147 {
1148 TIntermBlock &stmtList = *switchNode->getStatementList();
1149
1150 emitIndentation();
1151 mSink << "switch ";
1152 switchNode->getInit()->traverse(this);
1153 mSink << "\n";
1154
1155 emitOpenBrace();
1156
1157 // TODO(anglebug.com/42267100): Case statements that fall through need to combined into a single
1158 // case statement with multiple labels.
1159
1160 const size_t stmtCount = stmtList.getChildCount();
1161 bool inCaseList = false;
1162 size_t currStmt = 0;
1163 while (currStmt < stmtCount)
1164 {
1165 TIntermNode &stmtNode = *stmtList.getChildNode(currStmt);
1166 TIntermCase *caseNode = stmtNode.getAsCaseNode();
1167 if (caseNode)
1168 {
1169 if (inCaseList)
1170 {
1171 mSink << ", ";
1172 }
1173 else
1174 {
1175 emitIndentation();
1176 mSink << "case ";
1177 inCaseList = true;
1178 }
1179 caseNode->traverse(this);
1180
1181 // Process the next statement.
1182 currStmt++;
1183 }
1184 else
1185 {
1186 // The current statement is not a case statement, end the current case list and emit all
1187 // the code until the next case statement. WGSL requires braces around the case
1188 // statement's code.
1189 ASSERT(inCaseList);
1190 inCaseList = false;
1191 mSink << ":\n";
1192
1193 // Count the statements until the next case (or the end of the switch) and emit them as
1194 // a block. This assumes that the current statement list will never fallthrough to the
1195 // next case statement.
1196 size_t nextCaseStmt = currStmt + 1;
1197 for (;
1198 nextCaseStmt < stmtCount && !stmtList.getChildNode(nextCaseStmt)->getAsCaseNode();
1199 nextCaseStmt++)
1200 {
1201 }
1202 TSpan<TIntermNode *> stmtListView(&stmtList.getSequence()->at(currStmt),
1203 nextCaseStmt - currStmt);
1204 emitBlock(stmtListView);
1205 mSink << "\n";
1206
1207 // Skip to the next case statement.
1208 currStmt = nextCaseStmt;
1209 }
1210 }
1211
1212 emitCloseBrace();
1213
1214 return false;
1215 }
1216
visitCase(Visit,TIntermCase * caseNode)1217 bool OutputWGSLTraverser::visitCase(Visit, TIntermCase *caseNode)
1218 {
1219 // "case" will have been emitted in the visitSwitch() override.
1220
1221 if (caseNode->hasCondition())
1222 {
1223 TIntermTyped *condExpr = caseNode->getCondition();
1224 condExpr->traverse(this);
1225 }
1226 else
1227 {
1228 mSink << "default";
1229 }
1230
1231 return false;
1232 }
1233
emitFunctionReturn(const TFunction & func)1234 void OutputWGSLTraverser::emitFunctionReturn(const TFunction &func)
1235 {
1236 const TType &returnType = func.getReturnType();
1237 if (returnType.getBasicType() == EbtVoid)
1238 {
1239 return;
1240 }
1241 mSink << " -> ";
1242 emitType(returnType);
1243 }
1244
1245 // TODO(anglebug.com/42267100): Function overloads are not supported in WGSL, so function names
1246 // should either be emitted mangled or overloaded functions should be renamed in the AST as a
1247 // pre-pass. As of Apr 2024, WGSL function overloads are "not coming soon"
1248 // (https://github.com/gpuweb/gpuweb/issues/876).
emitFunctionSignature(const TFunction & func)1249 void OutputWGSLTraverser::emitFunctionSignature(const TFunction &func)
1250 {
1251 mSink << "fn ";
1252
1253 WriteNameOf(mSink, func);
1254 mSink << "(";
1255
1256 bool emitComma = false;
1257 const size_t paramCount = func.getParamCount();
1258 for (size_t i = 0; i < paramCount; ++i)
1259 {
1260 if (emitComma)
1261 {
1262 mSink << ", ";
1263 }
1264 emitComma = true;
1265
1266 const TVariable ¶m = *func.getParam(i);
1267 emitFunctionParameter(func, param);
1268 }
1269
1270 mSink << ")";
1271
1272 emitFunctionReturn(func);
1273 }
1274
emitFunctionParameter(const TFunction & func,const TVariable & param)1275 void OutputWGSLTraverser::emitFunctionParameter(const TFunction &func, const TVariable ¶m)
1276 {
1277 // TODO(anglebug.com/42267100): function parameters are immutable and will need to be renamed if
1278 // they are mutated.
1279 EmitVariableDeclarationConfig evdConfig;
1280 evdConfig.isParameter = true;
1281 emitVariableDeclaration({param.symbolType(), param.name(), param.getType()}, evdConfig);
1282 }
1283
visitFunctionPrototype(TIntermFunctionPrototype * funcProtoNode)1284 void OutputWGSLTraverser::visitFunctionPrototype(TIntermFunctionPrototype *funcProtoNode)
1285 {
1286 const TFunction &func = *funcProtoNode->getFunction();
1287
1288 emitIndentation();
1289 // TODO(anglebug.com/42267100): output correct signature for main() if main() is declared as a
1290 // function prototype, or perhaps just emit nothing.
1291 emitFunctionSignature(func);
1292 }
1293
visitFunctionDefinition(Visit,TIntermFunctionDefinition * funcDefNode)1294 bool OutputWGSLTraverser::visitFunctionDefinition(Visit, TIntermFunctionDefinition *funcDefNode)
1295 {
1296 const TFunction &func = *funcDefNode->getFunction();
1297 TIntermBlock &body = *funcDefNode->getBody();
1298
1299 emitIndentation();
1300 emitFunctionSignature(func);
1301 mSink << "\n";
1302 body.traverse(this);
1303
1304 return false;
1305 }
1306
visitAggregate(Visit,TIntermAggregate * aggregateNode)1307 bool OutputWGSLTraverser::visitAggregate(Visit, TIntermAggregate *aggregateNode)
1308 {
1309 const TIntermSequence &args = *aggregateNode->getSequence();
1310
1311 auto emitArgList = [&]() {
1312 mSink << "(";
1313
1314 bool emitComma = false;
1315 for (TIntermNode *arg : args)
1316 {
1317 if (emitComma)
1318 {
1319 mSink << ", ";
1320 }
1321 emitComma = true;
1322 arg->traverse(this);
1323 }
1324
1325 mSink << ")";
1326 };
1327
1328 const TType &retType = aggregateNode->getType();
1329
1330 if (aggregateNode->isConstructor())
1331 {
1332
1333 emitType(retType);
1334 emitArgList();
1335
1336 return false;
1337 }
1338 else
1339 {
1340 const TOperator op = aggregateNode->getOp();
1341 switch (op)
1342 {
1343 case TOperator::EOpCallFunctionInAST:
1344 WriteNameOf(mSink, *aggregateNode->getFunction());
1345 emitArgList();
1346 return false;
1347
1348 default:
1349 // Do not allow raw function calls, i.e. calls to functions
1350 // not present in the AST.
1351 ASSERT(op != TOperator::EOpCallInternalRawFunction);
1352 auto getArgType = [&](size_t index) -> const TType * {
1353 if (index < args.size())
1354 {
1355 TIntermTyped *arg = args[index]->getAsTyped();
1356 ASSERT(arg);
1357 return &arg->getType();
1358 }
1359 return nullptr;
1360 };
1361
1362 const TType *argType0 = getArgType(0);
1363 const TType *argType1 = getArgType(1);
1364 const TType *argType2 = getArgType(2);
1365
1366 const char *opName = GetOperatorString(op, retType, argType0, argType1, argType2);
1367
1368 if (IsSymbolicOperator(op, retType, argType0, argType1))
1369 {
1370 switch (args.size())
1371 {
1372 case 1:
1373 {
1374 TIntermNode &operandNode = *aggregateNode->getChildNode(0);
1375 if (IsPostfix(op))
1376 {
1377 mSink << opName;
1378 groupedTraverse(operandNode);
1379 }
1380 else
1381 {
1382 groupedTraverse(operandNode);
1383 mSink << opName;
1384 }
1385 return false;
1386 }
1387
1388 case 2:
1389 {
1390 // symbolic operators with 2 args are emitted with infix notation.
1391 TIntermNode &leftNode = *aggregateNode->getChildNode(0);
1392 TIntermNode &rightNode = *aggregateNode->getChildNode(1);
1393 groupedTraverse(leftNode);
1394 mSink << " " << opName << " ";
1395 groupedTraverse(rightNode);
1396 return false;
1397 }
1398
1399 default:
1400 UNREACHABLE();
1401 return false;
1402 }
1403 }
1404 else
1405 {
1406 if (opName == nullptr)
1407 {
1408 // TODO(anglebug.com/42267100): opName should not be allowed to be nullptr
1409 // here, but for now not all builtins are mapped to a string.
1410 opName = "TODO_Operator";
1411 }
1412 // If the operator is not symbolic then it is a builtin that uses function call
1413 // syntax: builtin(arg1, arg2, ..);
1414 mSink << opName;
1415 emitArgList();
1416 return false;
1417 }
1418 }
1419 }
1420 }
1421
emitBlock(TSpan<TIntermNode * > nodes)1422 bool OutputWGSLTraverser::emitBlock(TSpan<TIntermNode *> nodes)
1423 {
1424 ASSERT(mIndentLevel >= -1);
1425 const bool isGlobalScope = mIndentLevel == -1;
1426
1427 if (isGlobalScope)
1428 {
1429 ++mIndentLevel;
1430 }
1431 else
1432 {
1433 emitOpenBrace();
1434 }
1435
1436 TIntermNode *prevStmtNode = nullptr;
1437
1438 const size_t stmtCount = nodes.size();
1439 for (size_t i = 0; i < stmtCount; ++i)
1440 {
1441 TIntermNode &stmtNode = *nodes[i];
1442
1443 if (isGlobalScope && prevStmtNode && (NewlinePad(*prevStmtNode) || NewlinePad(stmtNode)))
1444 {
1445 mSink << "\n";
1446 }
1447 const bool isCase = stmtNode.getAsCaseNode();
1448 mIndentLevel -= isCase;
1449 emitIndentation();
1450 mIndentLevel += isCase;
1451 stmtNode.traverse(this);
1452 if (RequiresSemicolonTerminator(stmtNode))
1453 {
1454 mSink << ";";
1455 }
1456 mSink << "\n";
1457
1458 prevStmtNode = &stmtNode;
1459 }
1460
1461 if (isGlobalScope)
1462 {
1463 ASSERT(mIndentLevel == 0);
1464 --mIndentLevel;
1465 }
1466 else
1467 {
1468 emitCloseBrace();
1469 }
1470
1471 return false;
1472 }
1473
visitBlock(Visit,TIntermBlock * blockNode)1474 bool OutputWGSLTraverser::visitBlock(Visit, TIntermBlock *blockNode)
1475 {
1476 return emitBlock(TSpan(blockNode->getSequence()->data(), blockNode->getSequence()->size()));
1477 }
1478
visitGlobalQualifierDeclaration(Visit,TIntermGlobalQualifierDeclaration *)1479 bool OutputWGSLTraverser::visitGlobalQualifierDeclaration(Visit,
1480 TIntermGlobalQualifierDeclaration *)
1481 {
1482 return false;
1483 }
1484
emitStructDeclaration(const TType & type)1485 void OutputWGSLTraverser::emitStructDeclaration(const TType &type)
1486 {
1487 ASSERT(type.getBasicType() == TBasicType::EbtStruct);
1488 ASSERT(type.isStructSpecifier());
1489
1490 mSink << "struct ";
1491 emitBareTypeName(type);
1492
1493 mSink << "\n";
1494 emitOpenBrace();
1495
1496 const TStructure &structure = *type.getStruct();
1497 bool isInUniformAddressSpace =
1498 mUniformBlockMetadata->structsInUniformAddressSpace.count(structure.uniqueId().get()) != 0;
1499
1500 bool alignTo16InUniformAddressSpace = true;
1501 for (const TField *field : structure.fields())
1502 {
1503 emitIndentation();
1504 // If this struct is used in the uniform address space, it must obey the uniform address
1505 // space's layout constaints (https://www.w3.org/TR/WGSL/#address-space-layout-constraints).
1506 // WGSL's address space layout constraints nearly match std140, and the places they don't
1507 // are handled elsewhere.
1508 if (isInUniformAddressSpace)
1509 {
1510 // Here, the field must be aligned to 16 if:
1511 // 1. The field is a struct or array
1512 // 2. The previous field is a struct
1513 // 3. The field is the first in the struct (for convenience).
1514 if (field->type()->getStruct() || field->type()->isArray())
1515 {
1516 alignTo16InUniformAddressSpace = true;
1517 }
1518 if (alignTo16InUniformAddressSpace)
1519 {
1520 mSink << "@align(16) ";
1521 }
1522
1523 // If this field is a struct, the next member should be aligned to 16.
1524 alignTo16InUniformAddressSpace = field->type()->getStruct();
1525 }
1526
1527 // TODO(anglebug.com/42267100): emit qualifiers.
1528 EmitVariableDeclarationConfig evdConfig;
1529 evdConfig.disableStructSpecifier = true;
1530 emitVariableDeclaration({field->symbolType(), field->name(), *field->type()}, evdConfig);
1531 mSink << ",\n";
1532 }
1533
1534 emitCloseBrace();
1535 }
1536
emitVariableDeclaration(const VarDecl & decl,const EmitVariableDeclarationConfig & evdConfig)1537 void OutputWGSLTraverser::emitVariableDeclaration(const VarDecl &decl,
1538 const EmitVariableDeclarationConfig &evdConfig)
1539 {
1540 const TBasicType basicType = decl.type.getBasicType();
1541
1542 if (decl.type.getQualifier() == EvqUniform)
1543 {
1544 // Uniforms are declared in a pre-pass, and don't need to be outputted here.
1545 return;
1546 }
1547
1548 if (basicType == TBasicType::EbtStruct && decl.type.isStructSpecifier() &&
1549 !evdConfig.disableStructSpecifier)
1550 {
1551 // TODO(anglebug.com/42267100): in WGSL structs probably can't be declared in
1552 // function parameters or in uniform declarations or in variable declarations, or
1553 // anonymously either within other structs or within a variable declaration. Handle
1554 // these with the same AST pre-passes as other shader translators.
1555 ASSERT(!evdConfig.isParameter);
1556 emitStructDeclaration(decl.type);
1557 if (decl.symbolType != SymbolType::Empty)
1558 {
1559 mSink << " ";
1560 emitNameOf(decl);
1561 }
1562 return;
1563 }
1564
1565 ASSERT(basicType == TBasicType::EbtStruct || decl.symbolType != SymbolType::Empty ||
1566 evdConfig.isParameter);
1567
1568 if (evdConfig.needsVar)
1569 {
1570 // "const" and "let" probably don't need to be ever emitted because they are more for
1571 // readability, and the GLSL compiler constant folds most (all?) the consts anyway.
1572 mSink << "var";
1573 // TODO(anglebug.com/42267100): <workgroup> or <storage>?
1574 if (evdConfig.isGlobalScope)
1575 {
1576 mSink << "<private>";
1577 }
1578 mSink << " ";
1579 }
1580 else
1581 {
1582 ASSERT(!evdConfig.isGlobalScope);
1583 }
1584
1585 if (decl.symbolType != SymbolType::Empty)
1586 {
1587 emitNameOf(decl);
1588 }
1589 mSink << " : ";
1590 emitType(decl.type);
1591 }
1592
visitDeclaration(Visit,TIntermDeclaration * declNode)1593 bool OutputWGSLTraverser::visitDeclaration(Visit, TIntermDeclaration *declNode)
1594 {
1595 ASSERT(declNode->getChildCount() == 1);
1596 TIntermNode &node = *declNode->getChildNode(0);
1597
1598 EmitVariableDeclarationConfig evdConfig;
1599 evdConfig.needsVar = true;
1600 evdConfig.isGlobalScope = mIndentLevel == 0;
1601
1602 if (TIntermSymbol *symbolNode = node.getAsSymbolNode())
1603 {
1604 const TVariable &var = symbolNode->variable();
1605 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
1606 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
1607 {
1608 // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
1609 // outside of the traverser.
1610 return false;
1611 }
1612 emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
1613 }
1614 else if (TIntermBinary *initNode = node.getAsBinaryNode())
1615 {
1616 ASSERT(initNode->getOp() == TOperator::EOpInitialize);
1617 TIntermSymbol *leftSymbolNode = initNode->getLeft()->getAsSymbolNode();
1618 TIntermTyped *valueNode = initNode->getRight()->getAsTyped();
1619 ASSERT(leftSymbolNode && valueNode);
1620
1621 const TVariable &var = leftSymbolNode->variable();
1622 if (mRewritePipelineVarOutput->IsInputVar(var.uniqueId()) ||
1623 mRewritePipelineVarOutput->IsOutputVar(var.uniqueId()))
1624 {
1625 // Some variables, like shader inputs/outputs/builtins, are declared in the WGSL source
1626 // outside of the traverser.
1627 return false;
1628 }
1629
1630 emitVariableDeclaration({var.symbolType(), var.name(), var.getType()}, evdConfig);
1631 mSink << " = ";
1632 groupedTraverse(*valueNode);
1633 }
1634 else
1635 {
1636 UNREACHABLE();
1637 }
1638
1639 return false;
1640 }
1641
visitLoop(Visit,TIntermLoop * loopNode)1642 bool OutputWGSLTraverser::visitLoop(Visit, TIntermLoop *loopNode)
1643 {
1644 const TLoopType loopType = loopNode->getType();
1645
1646 switch (loopType)
1647 {
1648 case TLoopType::ELoopFor:
1649 return emitForLoop(loopNode);
1650 case TLoopType::ELoopWhile:
1651 return emitWhileLoop(loopNode);
1652 case TLoopType::ELoopDoWhile:
1653 return emulateDoWhileLoop(loopNode);
1654 }
1655 }
1656
emitForLoop(TIntermLoop * loopNode)1657 bool OutputWGSLTraverser::emitForLoop(TIntermLoop *loopNode)
1658 {
1659 ASSERT(loopNode->getType() == TLoopType::ELoopFor);
1660
1661 TIntermNode *initNode = loopNode->getInit();
1662 TIntermTyped *condNode = loopNode->getCondition();
1663 TIntermTyped *exprNode = loopNode->getExpression();
1664
1665 mSink << "for (";
1666
1667 if (initNode)
1668 {
1669 initNode->traverse(this);
1670 }
1671 else
1672 {
1673 mSink << " ";
1674 }
1675
1676 mSink << "; ";
1677
1678 if (condNode)
1679 {
1680 condNode->traverse(this);
1681 }
1682
1683 mSink << "; ";
1684
1685 if (exprNode)
1686 {
1687 exprNode->traverse(this);
1688 }
1689
1690 mSink << ")\n";
1691
1692 loopNode->getBody()->traverse(this);
1693
1694 return false;
1695 }
1696
emitWhileLoop(TIntermLoop * loopNode)1697 bool OutputWGSLTraverser::emitWhileLoop(TIntermLoop *loopNode)
1698 {
1699 ASSERT(loopNode->getType() == TLoopType::ELoopWhile);
1700
1701 TIntermNode *initNode = loopNode->getInit();
1702 TIntermTyped *condNode = loopNode->getCondition();
1703 TIntermTyped *exprNode = loopNode->getExpression();
1704 ASSERT(condNode);
1705 ASSERT(!initNode && !exprNode);
1706
1707 emitIndentation();
1708 mSink << "while (";
1709 condNode->traverse(this);
1710 mSink << ")\n";
1711 loopNode->getBody()->traverse(this);
1712
1713 return false;
1714 }
1715
emulateDoWhileLoop(TIntermLoop * loopNode)1716 bool OutputWGSLTraverser::emulateDoWhileLoop(TIntermLoop *loopNode)
1717 {
1718 ASSERT(loopNode->getType() == TLoopType::ELoopDoWhile);
1719
1720 TIntermNode *initNode = loopNode->getInit();
1721 TIntermTyped *condNode = loopNode->getCondition();
1722 TIntermTyped *exprNode = loopNode->getExpression();
1723 ASSERT(condNode);
1724 ASSERT(!initNode && !exprNode);
1725
1726 emitIndentation();
1727 // Write an infinite loop.
1728 mSink << "loop {\n";
1729 mIndentLevel++;
1730 loopNode->getBody()->traverse(this);
1731 mSink << "\n";
1732 emitIndentation();
1733 // At the end of the loop, break if the loop condition dos not still hold.
1734 mSink << "if (!(";
1735 condNode->traverse(this);
1736 mSink << ") { break; }\n";
1737 mIndentLevel--;
1738 emitIndentation();
1739 mSink << "}";
1740
1741 return false;
1742 }
1743
visitBranch(Visit,TIntermBranch * branchNode)1744 bool OutputWGSLTraverser::visitBranch(Visit, TIntermBranch *branchNode)
1745 {
1746 const TOperator flowOp = branchNode->getFlowOp();
1747 TIntermTyped *exprNode = branchNode->getExpression();
1748
1749 emitIndentation();
1750
1751 switch (flowOp)
1752 {
1753 case TOperator::EOpKill:
1754 {
1755 ASSERT(exprNode == nullptr);
1756 mSink << "discard";
1757 }
1758 break;
1759
1760 case TOperator::EOpReturn:
1761 {
1762 mSink << "return";
1763 if (exprNode)
1764 {
1765 mSink << " ";
1766 exprNode->traverse(this);
1767 }
1768 }
1769 break;
1770
1771 case TOperator::EOpBreak:
1772 {
1773 ASSERT(exprNode == nullptr);
1774 mSink << "break";
1775 }
1776 break;
1777
1778 case TOperator::EOpContinue:
1779 {
1780 ASSERT(exprNode == nullptr);
1781 mSink << "continue";
1782 }
1783 break;
1784
1785 default:
1786 {
1787 UNREACHABLE();
1788 }
1789 }
1790
1791 return false;
1792 }
1793
visitPreprocessorDirective(TIntermPreprocessorDirective * node)1794 void OutputWGSLTraverser::visitPreprocessorDirective(TIntermPreprocessorDirective *node)
1795 {
1796 // No preprocessor directives expected at this point.
1797 UNREACHABLE();
1798 }
1799
emitBareTypeName(const TType & type)1800 void OutputWGSLTraverser::emitBareTypeName(const TType &type)
1801 {
1802 WriteWgslBareTypeName(mSink, type);
1803 }
1804
emitType(const TType & type)1805 void OutputWGSLTraverser::emitType(const TType &type)
1806 {
1807 WriteWgslType(mSink, type);
1808 }
1809
1810 } // namespace
1811
TranslatorWGSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)1812 TranslatorWGSL::TranslatorWGSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
1813 : TCompiler(type, spec, output)
1814 {}
1815
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)1816 bool TranslatorWGSL::translate(TIntermBlock *root,
1817 const ShCompileOptions &compileOptions,
1818 PerformanceDiagnostics *perfDiagnostics)
1819 {
1820 if (kOutputTreeBeforeTranslation)
1821 {
1822 OutputTree(root, getInfoSink().info);
1823 std::cout << getInfoSink().info.c_str();
1824 }
1825
1826 RewritePipelineVarOutput rewritePipelineVarOutput(getShaderType());
1827
1828 // WGSL's main() will need to take parameters or return values if any glsl (input/output)
1829 // builtin variables are used.
1830 if (!GenerateMainFunctionAndIOStructs(*this, *root, rewritePipelineVarOutput))
1831 {
1832 return false;
1833 }
1834
1835 TInfoSinkBase &sink = getInfoSink().obj;
1836 // Start writing the output structs that will be referred to by the `traverser`'s output.'
1837 if (!rewritePipelineVarOutput.OutputStructs(sink))
1838 {
1839 return false;
1840 }
1841
1842 if (!OutputUniformBlocks(this, root))
1843 {
1844 return false;
1845 }
1846
1847 UniformBlockMetadata uniformBlockMetadata;
1848 if (!RecordUniformBlockMetadata(root, uniformBlockMetadata))
1849 {
1850 return false;
1851 }
1852
1853 // Write the body of the WGSL including the GLSL main() function.
1854 OutputWGSLTraverser traverser(this, &rewritePipelineVarOutput, &uniformBlockMetadata);
1855 root->traverse(&traverser);
1856
1857 // Write the actual WGSL main function, wgslMain(), which calls the GLSL main function.
1858 if (!rewritePipelineVarOutput.OutputMainFunction(sink))
1859 {
1860 return false;
1861 }
1862
1863 if (kOutputTranslatedShader)
1864 {
1865 std::cout << sink.str();
1866 }
1867
1868 return true;
1869 }
1870
shouldFlattenPragmaStdglInvariantAll()1871 bool TranslatorWGSL::shouldFlattenPragmaStdglInvariantAll()
1872 {
1873 // Not neccesary for WGSL transformation.
1874 return false;
1875 }
1876 } // namespace sh
1877