1 /*
2 * Copyright 2022 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkStringView.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLConstantFolder.h"
22 #include "src/sksl/SkSLContext.h"
23 #include "src/sksl/SkSLDefines.h"
24 #include "src/sksl/SkSLErrorReporter.h"
25 #include "src/sksl/SkSLIntrinsicList.h"
26 #include "src/sksl/SkSLMemoryLayout.h"
27 #include "src/sksl/SkSLOperator.h"
28 #include "src/sksl/SkSLOutputStream.h"
29 #include "src/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLString.h"
32 #include "src/sksl/SkSLStringStream.h"
33 #include "src/sksl/SkSLUtil.h"
34 #include "src/sksl/analysis/SkSLProgramUsage.h"
35 #include "src/sksl/analysis/SkSLProgramVisitor.h"
36 #include "src/sksl/codegen/SkSLCodeGenTypes.h"
37 #include "src/sksl/codegen/SkSLCodeGenerator.h"
38 #include "src/sksl/ir/SkSLBinaryExpression.h"
39 #include "src/sksl/ir/SkSLBlock.h"
40 #include "src/sksl/ir/SkSLConstructor.h"
41 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
42 #include "src/sksl/ir/SkSLConstructorCompound.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLDoStatement.h"
46 #include "src/sksl/ir/SkSLExpression.h"
47 #include "src/sksl/ir/SkSLExpressionStatement.h"
48 #include "src/sksl/ir/SkSLFieldAccess.h"
49 #include "src/sksl/ir/SkSLFieldSymbol.h"
50 #include "src/sksl/ir/SkSLForStatement.h"
51 #include "src/sksl/ir/SkSLFunctionCall.h"
52 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
53 #include "src/sksl/ir/SkSLFunctionDefinition.h"
54 #include "src/sksl/ir/SkSLIRHelpers.h"
55 #include "src/sksl/ir/SkSLIRNode.h"
56 #include "src/sksl/ir/SkSLIfStatement.h"
57 #include "src/sksl/ir/SkSLIndexExpression.h"
58 #include "src/sksl/ir/SkSLInterfaceBlock.h"
59 #include "src/sksl/ir/SkSLLayout.h"
60 #include "src/sksl/ir/SkSLLiteral.h"
61 #include "src/sksl/ir/SkSLModifierFlags.h"
62 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
63 #include "src/sksl/ir/SkSLPostfixExpression.h"
64 #include "src/sksl/ir/SkSLPrefixExpression.h"
65 #include "src/sksl/ir/SkSLProgram.h"
66 #include "src/sksl/ir/SkSLProgramElement.h"
67 #include "src/sksl/ir/SkSLReturnStatement.h"
68 #include "src/sksl/ir/SkSLSetting.h"
69 #include "src/sksl/ir/SkSLStatement.h"
70 #include "src/sksl/ir/SkSLStructDefinition.h"
71 #include "src/sksl/ir/SkSLSwitchCase.h"
72 #include "src/sksl/ir/SkSLSwitchStatement.h"
73 #include "src/sksl/ir/SkSLSwizzle.h"
74 #include "src/sksl/ir/SkSLTernaryExpression.h"
75 #include "src/sksl/ir/SkSLType.h"
76 #include "src/sksl/ir/SkSLVarDeclarations.h"
77 #include "src/sksl/ir/SkSLVariable.h"
78 #include "src/sksl/ir/SkSLVariableReference.h"
79 #include "src/sksl/spirv.h"
80 #include "src/sksl/transform/SkSLTransform.h"
81
82 #include <algorithm>
83 #include <cstddef>
84 #include <cstdint>
85 #include <initializer_list>
86 #include <iterator>
87 #include <memory>
88 #include <optional>
89 #include <string>
90 #include <string_view>
91 #include <utility>
92
93 using namespace skia_private;
94
95 namespace {
96
97 // Represents a function's dependencies that are not accessible in global scope. For instance,
98 // pipeline stage input and output parameters must be passed in as an argument.
99 //
100 // This is a bitmask enum. (It would be inside `class WGSLCodeGenerator`, but this leads to build
101 // errors in MSVC.)
102 enum class WGSLFunctionDependency : uint8_t {
103 kNone = 0,
104 kPipelineInputs = 1 << 0,
105 kPipelineOutputs = 1 << 1,
106 };
107 using WGSLFunctionDependencies = SkEnumBitMask<WGSLFunctionDependency>;
108
109 SK_MAKE_BITMASK_OPS(WGSLFunctionDependency)
110
111 } // namespace
112
113 namespace SkSL {
114
115 class WGSLCodeGenerator : public CodeGenerator {
116 public:
117 // See https://www.w3.org/TR/WGSL/#builtin-values
118 enum class Builtin {
119 // Vertex stage:
120 kVertexIndex, // input
121 kInstanceIndex, // input
122 kPosition, // output, fragment stage input
123
124 // Fragment stage:
125 kLastFragColor, // input
126 kFrontFacing, // input
127 kSampleIndex, // input
128 kFragDepth, // output
129 kSampleMaskIn, // input
130 kSampleMask, // output
131
132 // Compute stage:
133 kLocalInvocationId, // input
134 kLocalInvocationIndex, // input
135 kGlobalInvocationId, // input
136 kWorkgroupId, // input
137 kNumWorkgroups, // input
138 };
139
140 // Variable declarations can be terminated by:
141 // - comma (","), e.g. in struct member declarations or function parameters
142 // - semicolon (";"), e.g. in function scope variables
143 // A "none" option is provided to skip the delimiter when not needed, e.g. at the end of a list
144 // of declarations.
145 enum class Delimiter {
146 kComma,
147 kSemicolon,
148 kNone,
149 };
150
151 struct ProgramRequirements {
152 using DepsMap = skia_private::THashMap<const FunctionDeclaration*,
153 WGSLFunctionDependencies>;
154
155 // Mappings used to synthesize function parameters according to dependencies on pipeline
156 // input/output variables.
157 DepsMap fDependencies;
158
159 // These flags track extensions that will need to be enabled.
160 bool fPixelLocalExtension = false;
161 };
162
WGSLCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out,PrettyPrint pp,IncludeSyntheticCode isc)163 WGSLCodeGenerator(const Context* context,
164 const ShaderCaps* caps,
165 const Program* program,
166 OutputStream* out,
167 PrettyPrint pp,
168 IncludeSyntheticCode isc)
169 : CodeGenerator(context, caps, program, out)
170 , fPrettyPrint(pp)
171 , fGenSyntheticCode(isc) {}
172
173 bool generateCode() override;
174
175 private:
176 using Precedence = OperatorPrecedence;
177
178 // Called by generateCode() as the first step.
179 void preprocessProgram();
180
181 // Write output content while correctly handling indentation.
182 void write(std::string_view s);
183 void writeLine(std::string_view s = std::string_view());
184 void finishLine();
185
186 // Helpers to declare a pipeline stage IO parameter declaration.
187 void writePipelineIODeclaration(const Layout& layout,
188 const Type& type,
189 ModifierFlags modifiers,
190 std::string_view name,
191 Delimiter delimiter);
192 void writeUserDefinedIODecl(const Layout& layout,
193 const Type& type,
194 ModifierFlags modifiers,
195 std::string_view name,
196 Delimiter delimiter);
197 void writeBuiltinIODecl(const Type& type,
198 std::string_view name,
199 Builtin builtin,
200 Delimiter delimiter);
201 void writeVariableDecl(const Layout& layout,
202 const Type& type,
203 std::string_view name,
204 Delimiter delimiter);
205
206 // Write a function definition.
207 void writeFunction(const FunctionDefinition& f);
208 void writeFunctionDeclaration(const FunctionDeclaration& f,
209 SkSpan<const bool> paramNeedsDedicatedStorage);
210
211 // Write the program entry point.
212 void writeEntryPoint(const FunctionDefinition& f);
213
214 // Writers for supported statement types.
215 void writeStatement(const Statement& s);
216 void writeStatements(const StatementArray& statements);
217 void writeBlock(const Block& b);
218 void writeDoStatement(const DoStatement& expr);
219 void writeExpressionStatement(const Expression& expr);
220 void writeForStatement(const ForStatement& s);
221 void writeIfStatement(const IfStatement& s);
222 void writeReturnStatement(const ReturnStatement& s);
223 void writeSwitchStatement(const SwitchStatement& s);
224 void writeSwitchCases(SkSpan<const SwitchCase* const> cases);
225 void writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
226 std::string_view switchValue);
227 void writeSwitchCaseList(SkSpan<const SwitchCase* const> cases);
228 void writeVarDeclaration(const VarDeclaration& varDecl);
229
230 // Synthesizes an LValue for an expression.
231 class LValue;
232 class PointerLValue;
233 class SwizzleLValue;
234 class VectorComponentLValue;
235 std::unique_ptr<LValue> makeLValue(const Expression& e);
236
237 std::string variableReferenceNameForLValue(const VariableReference& r);
238 std::string variablePrefix(const Variable& v);
239
240 bool binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left, const Type& right, Operator op);
241
242 // Writers for expressions. These return the final expression text as a string, and emit any
243 // necessary setup code directly into the program as necessary. The returned expression may be
244 // a `let`-alias that cannot be assigned-into; use `makeLValue` for an assignable expression.
245 std::string assembleExpression(const Expression& e, Precedence parentPrecedence);
246 std::string assembleBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
247 std::string assembleBinaryExpression(const Expression& left,
248 Operator op,
249 const Expression& right,
250 const Type& resultType,
251 Precedence parentPrecedence);
252 std::string assembleFieldAccess(const FieldAccess& f);
253 std::string assembleFunctionCall(const FunctionCall& call, Precedence parentPrecedence);
254 std::string assembleIndexExpression(const IndexExpression& i);
255 std::string assembleLiteral(const Literal& l);
256 std::string assemblePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
257 std::string assemblePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
258 std::string assembleSwizzle(const Swizzle& swizzle);
259 std::string assembleTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
260 std::string assembleVariableReference(const VariableReference& r);
261 std::string assembleName(std::string_view name);
262
263 std::string assembleIncrementExpr(const Type& type);
264
265 // Intrinsic helper functions.
266 std::string assembleIntrinsicCall(const FunctionCall& call,
267 IntrinsicKind kind,
268 Precedence parentPrecedence);
269 std::string assembleSimpleIntrinsic(std::string_view intrinsicName, const FunctionCall& call);
270 std::string assembleUnaryOpIntrinsic(Operator op,
271 const FunctionCall& call,
272 Precedence parentPrecedence);
273 std::string assembleBinaryOpIntrinsic(Operator op,
274 const FunctionCall& call,
275 Precedence parentPrecedence);
276 std::string assembleVectorizedIntrinsic(std::string_view intrinsicName,
277 const FunctionCall& call);
278 std::string assembleOutAssignedIntrinsic(std::string_view intrinsicName,
279 std::string_view returnFieldName,
280 std::string_view outFieldName,
281 const FunctionCall& call);
282 std::string assemblePartialSampleCall(std::string_view intrinsicName,
283 const Expression& sampler,
284 const Expression& coords);
285 std::string assembleInversePolyfill(const FunctionCall& call);
286 std::string assembleComponentwiseMatrixBinary(const Type& leftType,
287 const Type& rightType,
288 const std::string& left,
289 const std::string& right,
290 Operator op);
291
292 // Constructor expressions
293 std::string assembleAnyConstructor(const AnyConstructor& c);
294 std::string assembleConstructorCompound(const ConstructorCompound& c);
295 std::string assembleConstructorCompoundVector(const ConstructorCompound& c);
296 std::string assembleConstructorCompoundMatrix(const ConstructorCompound& c);
297 std::string assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
298 std::string assembleConstructorMatrixResize(const ConstructorMatrixResize& ctor);
299
300 // Synthesized helper functions for comparison operators that are not supported by WGSL.
301 std::string assembleEqualityExpression(const Type& left,
302 const std::string& leftName,
303 const Type& right,
304 const std::string& rightName,
305 Operator op,
306 Precedence parentPrecedence);
307 std::string assembleEqualityExpression(const Expression& left,
308 const Expression& right,
309 Operator op,
310 Precedence parentPrecedence);
311
312 // Writes a scratch variable into the program and returns its name (e.g. `_skTemp123`).
313 std::string writeScratchVar(const Type& type, const std::string& value = "");
314
315 // Writes a scratch let-variable into the program, gives it the value of `expr`, and returns its
316 // name (e.g. `_skTemp123`).
317 std::string writeScratchLet(const std::string& expr, bool isCompileTimeConstant = false);
318 std::string writeScratchLet(const Expression& expr, Precedence parentPrecedence);
319
320 // Converts `expr` into a string and returns a scratch let-variable associated with the
321 // expression. Compile-time constants and plain variable references will return the expression
322 // directly and omit the let-variable.
323 std::string writeNontrivialScratchLet(const Expression& expr, Precedence parentPrecedence);
324
325 // Generic recursive ProgramElement visitor.
326 void writeProgramElement(const ProgramElement& e);
327 void writeGlobalVarDeclaration(const GlobalVarDeclaration& d);
328 void writeStructDefinition(const StructDefinition& s);
329 void writeModifiersDeclaration(const ModifiersDeclaration&);
330
331 // Writes the WGSL struct fields for SkSL structs and interface blocks. Enforces WGSL address
332 // space layout constraints
333 // (https://www.w3.org/TR/WGSL/#address-space-layout-constraints) if a `layout` is
334 // provided. A struct that does not need to be host-shareable does not require a `layout`.
335 void writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout = nullptr);
336
337 // We bundle uniforms, and all varying pipeline stage inputs and outputs, into separate structs.
338 bool needsStageInputStruct() const;
339 void writeStageInputStruct();
340 bool needsStageOutputStruct() const;
341 void writeStageOutputStruct();
342 void writeUniformsAndBuffers();
343 void prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock* interfaceBlock,
344 std::string_view instanceName,
345 MemoryLayout::Standard nativeLayout);
346 void writeEnables();
347 void writeUniformPolyfills();
348
349 void writeTextureOrSampler(const Variable& var,
350 int bindingLocation,
351 std::string_view suffix,
352 std::string_view wgslType);
353
354 // Writes all top-level non-opaque global uniform declarations (i.e. not part of an interface
355 // block) into a single uniform block binding.
356 //
357 // In complete fragment/vertex/compute programs, uniforms will be declared only as interface
358 // blocks and global opaque types (like textures and samplers) which we expect to be declared
359 // with a unique binding and descriptor set index. However, test files that are declared as RTE
360 // programs may contain OpenGL-style global uniform declarations with no clear binding index to
361 // use for the containing synthesized block.
362 //
363 // Since we are handling these variables only to generate gold files from RTEs and never run
364 // them, we always declare them at the default bind group and binding index.
365 void writeNonBlockUniformsForTests();
366
367 // For a given function declaration, writes out any implicitly required pipeline stage arguments
368 // based on the function's pre-determined dependencies. These are expected to be written out as
369 // the first parameters for a function that requires them. Returns true if any arguments were
370 // written.
371 std::string functionDependencyArgs(const FunctionDeclaration&);
372 bool writeFunctionDependencyParams(const FunctionDeclaration&);
373
374 // Code in the header appears before the main body of code.
375 StringStream fHeader;
376
377 // We assign unique names to anonymous interface blocks based on the type.
378 skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
379
380 // Stores the functions which use stage inputs/outputs as well as required WGSL extensions.
381 ProgramRequirements fRequirements;
382 skia_private::TArray<const Variable*> fPipelineInputs;
383 skia_private::TArray<const Variable*> fPipelineOutputs;
384
385 // These fields track whether we have written the polyfill for `inverse()` for a given matrix
386 // type.
387 bool fWrittenInverse2 = false;
388 bool fWrittenInverse3 = false;
389 bool fWrittenInverse4 = false;
390 PrettyPrint fPrettyPrint;
391 IncludeSyntheticCode fGenSyntheticCode;
392
393 // These fields control uniform polyfill support in cases where WGSL and std140 disagree.
394 // In std140 layout, matrices need to be represented as arrays of @size(16)-aligned vectors, and
395 // array elements are wrapped in a struct containing a single @size(16)-aligned element. Arrays
396 // of matrices combine both wrappers. These wrapper structs are unpacked into natively-typed
397 // globals at the shader entrypoint.
398 struct FieldPolyfillInfo {
399 const InterfaceBlock* fInterfaceBlock;
400 std::string fReplacementName;
401 bool fIsArray = false;
402 bool fIsMatrix = false;
403 bool fWasAccessed = false;
404 };
405 using FieldPolyfillMap = skia_private::THashMap<const Field*, FieldPolyfillInfo>;
406 FieldPolyfillMap fFieldPolyfillMap;
407
408 // Output processing state.
409 int fIndentation = 0;
410 bool fAtLineStart = false;
411 bool fHasUnconditionalReturn = false;
412 bool fAtFunctionScope = false;
413 int fConditionalScopeDepth = 0;
414 int fLocalSizeX = 1;
415 int fLocalSizeY = 1;
416 int fLocalSizeZ = 1;
417
418 int fScratchCount = 0;
419 };
420
// Forward declaration of the SkSL program-kind enum, so that helpers in this file such as
// pipeline_struct_prefix() can take it as a parameter.
enum class ProgramKind : int8_t;
422
423 namespace {
424
// Name suffixes for the sampler and texture halves of an SkSL sampler.
// (`constexpr` at namespace scope already implies internal linkage, so `static` is implied.)
constexpr char kSamplerSuffix[] = "_Sampler";
constexpr char kTextureSuffix[] = "_Texture";
427
// Address spaces that may appear in a WGSL `ptr<...>` type produced by this generator.
// See https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace { kFunction, kPrivate, kStorage };
434
operator_name(Operator op)435 const char* operator_name(Operator op) {
436 switch (op.kind()) {
437 case Operator::Kind::LOGICALXOR: return " != ";
438 default: return op.operatorName();
439 }
440 }
441
// Returns true if `word` is in the set of names this generator treats as reserved: names the
// generator itself emits (pipeline struct and global names), WGSL keywords, pre-declared types
// and type generators, pre-declared enumerants, and WGSL reserved words.
//
// The lookup set is a function-local static, so it is built once, on the first call.
bool is_reserved_word(std::string_view word) {
    static const THashSet<std::string_view> kReservedWords{
            // Used by SkSL:
            "FSIn",
            "FSOut",
            "VSIn",
            "VSOut",
            "CSIn",
            "_globalUniforms",
            "_GlobalUniforms",
            "_return",
            "_stageIn",
            "_stageOut",
            // Keywords: https://www.w3.org/TR/WGSL/#keyword-summary
            "alias",
            "break",
            "case",
            "const",
            "const_assert",
            "continue",
            "continuing",
            "default",
            "diagnostic",
            "discard",
            "else",
            "enable",
            "false",
            "fn",
            "for",
            "if",
            "let",
            "loop",
            "override",
            "requires",
            "return",
            "struct",
            "switch",
            "true",
            "var",
            "while",
            // Pre-declared types: https://www.w3.org/TR/WGSL/#predeclared-types
            "bool",
            "f16",
            "f32",
            "i32",
            "u32",
            // ... and pre-declared type generators:
            "array",
            "atomic",
            "mat2x2",
            "mat2x3",
            "mat2x4",
            "mat3x2",
            "mat3x3",
            "mat3x4",
            "mat4x2",
            "mat4x3",
            "mat4x4",
            "ptr",
            "texture_1d",
            "texture_2d",
            "texture_2d_array",
            "texture_3d",
            "texture_cube",
            "texture_cube_array",
            "texture_multisampled_2d",
            "texture_storage_1d",
            "texture_storage_2d",
            "texture_storage_2d_array",
            "texture_storage_3d",
            "vec2",
            "vec3",
            "vec4",
            // Pre-declared enumerants: https://www.w3.org/TR/WGSL/#predeclared-enumerants
            "read",
            "write",
            "read_write",
            "function",
            "private",
            "workgroup",
            "uniform",
            "storage",
            "perspective",
            "linear",
            "flat",
            "center",
            "centroid",
            "sample",
            "vertex_index",
            "instance_index",
            "position",
            "front_facing",
            "frag_depth",
            "local_invocation_id",
            "local_invocation_index",
            "global_invocation_id",
            "workgroup_id",
            "num_workgroups",
            "sample_index",
            "sample_mask",
            "rgba8unorm",
            "rgba8snorm",
            "rgba8uint",
            "rgba8sint",
            "rgba16uint",
            "rgba16sint",
            "rgba16float",
            "r32uint",
            "r32sint",
            "r32float",
            "rg32uint",
            "rg32sint",
            "rg32float",
            "rgba32uint",
            "rgba32sint",
            "rgba32float",
            "bgra8unorm",
            // Reserved words: https://www.w3.org/TR/WGSL/#reserved-words
            "_",
            "NULL",
            "Self",
            "abstract",
            "active",
            "alignas",
            "alignof",
            "as",
            "asm",
            "asm_fragment",
            "async",
            "attribute",
            "auto",
            "await",
            "become",
            "binding_array",
            "cast",
            "catch",
            "class",
            "co_await",
            "co_return",
            "co_yield",
            "coherent",
            "column_major",
            "common",
            "compile",
            "compile_fragment",
            "concept",
            "const_cast",
            "consteval",
            "constexpr",
            "constinit",
            "crate",
            "debugger",
            "decltype",
            "delete",
            "demote",
            "demote_to_helper",
            "do",
            "dynamic_cast",
            "enum",
            "explicit",
            "export",
            "extends",
            "extern",
            "external",
            "fallthrough",
            "filter",
            "final",
            "finally",
            "friend",
            "from",
            "fxgroup",
            "get",
            "goto",
            "groupshared",
            "highp",
            "impl",
            "implements",
            "import",
            "inline",
            "instanceof",
            "interface",
            "layout",
            "lowp",
            "macro",
            "macro_rules",
            "match",
            "mediump",
            "meta",
            "mod",
            "module",
            "move",
            "mut",
            "mutable",
            "namespace",
            "new",
            "nil",
            "noexcept",
            "noinline",
            "nointerpolation",
            "noperspective",
            "null",
            "nullptr",
            "of",
            "operator",
            "package",
            "packoffset",
            "partition",
            "pass",
            "patch",
            "pixelfragment",
            "precise",
            "precision",
            "premerge",
            "priv",
            "protected",
            "pub",
            "public",
            "readonly",
            "ref",
            "regardless",
            "register",
            "reinterpret_cast",
            "require",
            "resource",
            "restrict",
            "self",
            "set",
            "shared",
            "sizeof",
            "smooth",
            "snorm",
            "static",
            "static_assert",
            "static_cast",
            "std",
            "subroutine",
            "super",
            "target",
            "template",
            "this",
            "thread_local",
            "throw",
            "trait",
            "try",
            "type",
            "typedef",
            "typeid",
            "typename",
            "typeof",
            "union",
            "unless",
            "unorm",
            "unsafe",
            "unsized",
            "use",
            "using",
            "varying",
            "virtual",
            "volatile",
            "wgsl",
            "where",
            "with",
            "writeonly",
            "yield",
    };

    return kReservedWords.contains(word);
}
710
pipeline_struct_prefix(ProgramKind kind)711 std::string_view pipeline_struct_prefix(ProgramKind kind) {
712 if (ProgramConfig::IsVertex(kind)) {
713 return "VS";
714 }
715 if (ProgramConfig::IsFragment(kind)) {
716 return "FS";
717 }
718 if (ProgramConfig::IsCompute(kind)) {
719 return "CS";
720 }
721 // Compute programs don't have stage-in/stage-out pipeline structs.
722 return "";
723 }
724
address_space_to_str(PtrAddressSpace addressSpace)725 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
726 switch (addressSpace) {
727 case PtrAddressSpace::kFunction:
728 return "function";
729 case PtrAddressSpace::kPrivate:
730 return "private";
731 case PtrAddressSpace::kStorage:
732 return "storage";
733 }
734 SkDEBUGFAIL("unsupported ptr address space");
735 return "unsupported";
736 }
737
to_scalar_type(const Type & type)738 std::string_view to_scalar_type(const Type& type) {
739 SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
740 switch (type.numberKind()) {
741 // Floating-point numbers in WebGPU currently always have 32-bit footprint and
742 // relaxed-precision is not supported without extensions. f32 is the only floating-point
743 // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
744 case Type::NumberKind::kFloat:
745 return "f32";
746 case Type::NumberKind::kSigned:
747 return "i32";
748 case Type::NumberKind::kUnsigned:
749 return "u32";
750 case Type::NumberKind::kBoolean:
751 return "bool";
752 case Type::NumberKind::kNonnumeric:
753 [[fallthrough]];
754 default:
755 break;
756 }
757 return type.name();
758 }
759
760 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
761 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Context & context,const Type & raw,const Layout * layout=nullptr)762 std::string to_wgsl_type(const Context& context, const Type& raw, const Layout* layout = nullptr) {
763 const Type& type = raw.resolve().scalarTypeForLiteral();
764 switch (type.typeKind()) {
765 case Type::TypeKind::kScalar:
766 return std::string(to_scalar_type(type));
767
768 case Type::TypeKind::kAtomic:
769 SkASSERT(type.matches(*context.fTypes.fAtomicUInt));
770 return "atomic<u32>";
771
772 case Type::TypeKind::kVector: {
773 std::string_view ct = to_scalar_type(type.componentType());
774 return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
775 }
776 case Type::TypeKind::kMatrix: {
777 std::string_view ct = to_scalar_type(type.componentType());
778 return String::printf("mat%dx%d<%.*s>",
779 type.columns(), type.rows(), (int)ct.length(), ct.data());
780 }
781 case Type::TypeKind::kArray: {
782 std::string result = "array<" + to_wgsl_type(context, type.componentType(), layout);
783 if (!type.isUnsizedArray()) {
784 result += ", ";
785 result += std::to_string(type.columns());
786 }
787 return result + '>';
788 }
789 case Type::TypeKind::kTexture: {
790 if (type.matches(*context.fTypes.fWriteOnlyTexture2D)) {
791 std::string result = "texture_storage_2d<";
792 // Write-only storage texture types require a pixel format, which is in the layout.
793 SkASSERT(layout);
794 LayoutFlags pixelFormat = layout->fFlags & LayoutFlag::kAllPixelFormats;
795 switch (pixelFormat.value()) {
796 case (int)LayoutFlag::kRGBA8:
797 return result + "rgba8unorm, write>";
798
799 case (int)LayoutFlag::kRGBA32F:
800 return result + "rgba32float, write>";
801
802 case (int)LayoutFlag::kR32F:
803 return result + "r32float, write>";
804
805 default:
806 // The front-end should have rejected this.
807 return result + "write>";
808 }
809 }
810 if (type.matches(*context.fTypes.fReadOnlyTexture2D)) {
811 return "texture_2d<f32>";
812 }
813 break;
814 }
815 default:
816 break;
817 }
818 return std::string(type.name());
819 }
820
to_ptr_type(const Context & context,const Type & type,const Layout * layout,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)821 std::string to_ptr_type(const Context& context,
822 const Type& type,
823 const Layout* layout,
824 PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
825 return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " +
826 to_wgsl_type(context, type, layout) + '>';
827 }
828
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)829 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
830 using Builtin = WGSLCodeGenerator::Builtin;
831 switch (builtin) {
832 case Builtin::kVertexIndex:
833 return "@builtin(vertex_index)";
834 case Builtin::kInstanceIndex:
835 return "@builtin(instance_index)";
836 case Builtin::kPosition:
837 return "@builtin(position)";
838 case Builtin::kLastFragColor:
839 return "@color(0)";
840 case Builtin::kFrontFacing:
841 return "@builtin(front_facing)";
842 case Builtin::kSampleIndex:
843 return "@builtin(sample_index)";
844 case Builtin::kFragDepth:
845 return "@builtin(frag_depth)";
846 case Builtin::kSampleMask:
847 case Builtin::kSampleMaskIn:
848 return "@builtin(sample_mask)";
849 case Builtin::kLocalInvocationId:
850 return "@builtin(local_invocation_id)";
851 case Builtin::kLocalInvocationIndex:
852 return "@builtin(local_invocation_index)";
853 case Builtin::kGlobalInvocationId:
854 return "@builtin(global_invocation_id)";
855 case Builtin::kWorkgroupId:
856 return "@builtin(workgroup_id)";
857 case Builtin::kNumWorkgroups:
858 return "@builtin(num_workgroups)";
859 default:
860 break;
861 }
862
863 SkDEBUGFAIL("unsupported builtin");
864 return "unsupported";
865 }
866
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)867 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
868 using Builtin = WGSLCodeGenerator::Builtin;
869 switch (builtin) {
870 case Builtin::kVertexIndex:
871 return "u32";
872 case Builtin::kInstanceIndex:
873 return "u32";
874 case Builtin::kPosition:
875 return "vec4<f32>";
876 case Builtin::kLastFragColor:
877 return "vec4<f32>";
878 case Builtin::kFrontFacing:
879 return "bool";
880 case Builtin::kSampleIndex:
881 return "u32";
882 case Builtin::kFragDepth:
883 return "f32";
884 case Builtin::kSampleMask:
885 return "u32";
886 case Builtin::kSampleMaskIn:
887 return "u32";
888 case Builtin::kLocalInvocationId:
889 return "vec3<u32>";
890 case Builtin::kLocalInvocationIndex:
891 return "u32";
892 case Builtin::kGlobalInvocationId:
893 return "vec3<u32>";
894 case Builtin::kWorkgroupId:
895 return "vec3<u32>";
896 case Builtin::kNumWorkgroups:
897 return "vec3<u32>";
898 default:
899 break;
900 }
901
902 SkDEBUGFAIL("unsupported builtin");
903 return "unsupported";
904 }
905
906 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
907 // unsigned integer). We handle these cases with an explicit type conversion during a variable
908 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
909 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)910 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
911 switch (v.layout().fBuiltin) {
912 case SK_VERTEXID_BUILTIN:
913 case SK_INSTANCEID_BUILTIN:
914 return {"i32"};
915 default:
916 break;
917 }
918 return std::nullopt;
919 }
920
921 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
922 // not supported for WGSL.
923 //
924 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)925 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
926 using Builtin = WGSLCodeGenerator::Builtin;
927 switch (builtin) {
928 case SK_POSITION_BUILTIN:
929 [[fallthrough]];
930 case SK_FRAGCOORD_BUILTIN:
931 return Builtin::kPosition;
932 case SK_VERTEXID_BUILTIN:
933 return Builtin::kVertexIndex;
934 case SK_INSTANCEID_BUILTIN:
935 return Builtin::kInstanceIndex;
936 case SK_LASTFRAGCOLOR_BUILTIN:
937 return Builtin::kLastFragColor;
938 case SK_CLOCKWISE_BUILTIN:
939 // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
940 // imply a particular winding order. We correctly compute the face orientation based
941 // on how Skia configured the render pipeline for all references to this built-in
942 // variable (see `SkSL::Program::Interface::fRTFlipUniform`).
943 return Builtin::kFrontFacing;
944 case SK_SAMPLEMASKIN_BUILTIN:
945 return Builtin::kSampleMaskIn;
946 case SK_SAMPLEMASK_BUILTIN:
947 return Builtin::kSampleMask;
948 case SK_NUMWORKGROUPS_BUILTIN:
949 return Builtin::kNumWorkgroups;
950 case SK_WORKGROUPID_BUILTIN:
951 return Builtin::kWorkgroupId;
952 case SK_LOCALINVOCATIONID_BUILTIN:
953 return Builtin::kLocalInvocationId;
954 case SK_GLOBALINVOCATIONID_BUILTIN:
955 return Builtin::kGlobalInvocationId;
956 case SK_LOCALINVOCATIONINDEX_BUILTIN:
957 return Builtin::kLocalInvocationIndex;
958 default:
959 break;
960 }
961 return std::nullopt;
962 }
963
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)964 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
965 using Delim = WGSLCodeGenerator::Delimiter;
966 switch (delimiter) {
967 case Delim::kComma:
968 return ",";
969 case Delim::kSemicolon:
970 return ";";
971 case Delim::kNone:
972 default:
973 break;
974 }
975 return "";
976 }
977
978 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
979 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
980 // synthesize arguments when writing out function definitions.
981 class FunctionDependencyResolver : public ProgramVisitor {
982 public:
983 using Deps = WGSLFunctionDependencies;
984 using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
985
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)986 FunctionDependencyResolver(const Program* p,
987 const FunctionDeclaration* f,
988 DepsMap* programDependencyMap)
989 : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
990
resolve()991 Deps resolve() {
992 fDeps = WGSLFunctionDependency::kNone;
993 this->visit(*fProgram);
994 return fDeps;
995 }
996
997 private:
visitProgramElement(const ProgramElement & p)998 bool visitProgramElement(const ProgramElement& p) override {
999 // Only visit the program that matches the requested function.
1000 if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
1001 return ProgramVisitor::visitProgramElement(p);
1002 }
1003 // Continue visiting other program elements.
1004 return false;
1005 }
1006
visitExpression(const Expression & e)1007 bool visitExpression(const Expression& e) override {
1008 if (e.is<VariableReference>()) {
1009 const VariableReference& v = e.as<VariableReference>();
1010 if (v.variable()->storage() == Variable::Storage::kGlobal) {
1011 ModifierFlags flags = v.variable()->modifierFlags();
1012 if (flags & ModifierFlag::kIn) {
1013 fDeps |= WGSLFunctionDependency::kPipelineInputs;
1014 }
1015 if (flags & ModifierFlag::kOut) {
1016 fDeps |= WGSLFunctionDependency::kPipelineOutputs;
1017 }
1018 }
1019 } else if (e.is<FunctionCall>()) {
1020 // The current function that we're processing (`fFunction`) inherits the dependencies of
1021 // functions that it makes calls to, because the pipeline stage IO parameters need to be
1022 // passed down as an argument.
1023 const FunctionCall& callee = e.as<FunctionCall>();
1024
1025 // Don't process a function again if we have already resolved it.
1026 Deps* found = fDependencyMap->find(&callee.function());
1027 if (found) {
1028 fDeps |= *found;
1029 } else {
1030 // Store the dependencies that have been discovered for the current function so far.
1031 // If `callee` directly or indirectly calls the current function, then this value
1032 // will prevent an infinite recursion.
1033 fDependencyMap->set(fFunction, fDeps);
1034
1035 // Separately traverse the called function's definition and determine its
1036 // dependencies.
1037 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
1038 Deps calleeDeps = resolver.resolve();
1039
1040 // Store the callee's dependencies in the global map to avoid processing
1041 // the function again for future calls.
1042 fDependencyMap->set(&callee.function(), calleeDeps);
1043
1044 // Add to the current function's dependencies.
1045 fDeps |= calleeDeps;
1046 }
1047 }
1048 return ProgramVisitor::visitExpression(e);
1049 }
1050
1051 const Program* const fProgram;
1052 const FunctionDeclaration* const fFunction;
1053 DepsMap* const fDependencyMap;
1054 Deps fDeps = WGSLFunctionDependency::kNone;
1055 };
1056
resolve_program_requirements(const Program * program)1057 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
1058 WGSLCodeGenerator::ProgramRequirements requirements;
1059
1060 for (const ProgramElement* e : program->elements()) {
1061 switch (e->kind()) {
1062 case ProgramElement::Kind::kFunction: {
1063 const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
1064
1065 FunctionDependencyResolver resolver(program, &decl, &requirements.fDependencies);
1066 requirements.fDependencies.set(&decl, resolver.resolve());
1067 break;
1068 }
1069 case ProgramElement::Kind::kGlobalVar: {
1070 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
1071 if (decl.varDeclaration().var()->modifierFlags().isPixelLocal()) {
1072 requirements.fPixelLocalExtension = true;
1073 }
1074 break;
1075 }
1076 default:
1077 break;
1078 }
1079 }
1080
1081 return requirements;
1082 }
1083
collect_pipeline_io_vars(const Program * program,TArray<const Variable * > * ioVars,ModifierFlag ioType)1084 void collect_pipeline_io_vars(const Program* program,
1085 TArray<const Variable*>* ioVars,
1086 ModifierFlag ioType) {
1087 for (const ProgramElement* e : program->elements()) {
1088 if (e->is<GlobalVarDeclaration>()) {
1089 const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
1090 if (v->modifierFlags() & ioType) {
1091 ioVars->push_back(v);
1092 }
1093 } else if (e->is<InterfaceBlock>()) {
1094 const Variable* v = e->as<InterfaceBlock>().var();
1095 if (v->modifierFlags() & ioType) {
1096 ioVars->push_back(v);
1097 }
1098 }
1099 }
1100 }
1101
is_in_global_uniforms(const Variable & var)1102 bool is_in_global_uniforms(const Variable& var) {
1103 SkASSERT(var.storage() == VariableStorage::kGlobal);
1104 return var.modifierFlags().isUniform() &&
1105 !var.type().isOpaque() &&
1106 !var.interfaceBlock();
1107 }
1108
1109 } // namespace
1110
// Abstract interface for an assignable WGSL location. Implementations produce WGSL source text for
// reading the location's current value and for assigning a new value to it.
class WGSLCodeGenerator::LValue {
public:
    virtual ~LValue() = default;

    // Returns a WGSL expression that loads from the lvalue with no side effects.
    // (e.g. `array[index].field`)
    virtual std::string load() = 0;

    // Returns a complete WGSL statement (including the trailing `;`) that stores `value` into the
    // lvalue with no side effects. `value` is WGSL source text for the new value.
    // (e.g. `array[index].field = the_passed_in_value_string;`)
    virtual std::string store(const std::string& value) = 0;
};
1123
1124 class WGSLCodeGenerator::PointerLValue : public WGSLCodeGenerator::LValue {
1125 public:
1126 // `name` must be a WGSL expression with no side-effects, which we can safely take the address
1127 // of. (e.g. `array[index].field` would be valid, but `array[Func()]` or `vector.x` are not.)
PointerLValue(std::string name)1128 PointerLValue(std::string name) : fName(std::move(name)) {}
1129
load()1130 std::string load() override {
1131 return fName;
1132 }
1133
store(const std::string & value)1134 std::string store(const std::string& value) override {
1135 return fName + " = " + value + ";";
1136 }
1137
1138 private:
1139 std::string fName;
1140 };
1141
1142 class WGSLCodeGenerator::VectorComponentLValue : public WGSLCodeGenerator::LValue {
1143 public:
1144 // `name` must be a WGSL expression with no side-effects that points to a single component of a
1145 // WGSL vector.
VectorComponentLValue(std::string name)1146 VectorComponentLValue(std::string name) : fName(std::move(name)) {}
1147
load()1148 std::string load() override {
1149 return fName;
1150 }
1151
store(const std::string & value)1152 std::string store(const std::string& value) override {
1153 return fName + " = " + value + ";";
1154 }
1155
1156 private:
1157 std::string fName;
1158 };
1159
1160 class WGSLCodeGenerator::SwizzleLValue : public WGSLCodeGenerator::LValue {
1161 public:
1162 // `name` must be a WGSL expression with no side-effects that points to a WGSL vector.
SwizzleLValue(const Context & ctx,std::string name,const Type & t,const ComponentArray & c)1163 SwizzleLValue(const Context& ctx, std::string name, const Type& t, const ComponentArray& c)
1164 : fContext(ctx)
1165 , fName(std::move(name))
1166 , fType(t)
1167 , fComponents(c) {
1168 // If the component array doesn't cover the entire value, we need to create masks for
1169 // writing back into the lvalue. For example, if the type is vec4 and the component array
1170 // holds `zx`, a GLSL assignment would look like:
1171 // name.zx = new_value;
1172 //
1173 // The equivalent WGSL assignment statement would look like:
1174 // name = vec4<f32>(new_value, name.xw).yzxw;
1175 //
1176 // This replaces name.zy with new_value.xy, and leaves name.xw at their original values.
1177 // By convention, we always put the new value first and the original values second; it might
1178 // be possible to find better arrangements which simplify the assignment overall, but we
1179 // don't attempt this.
1180 int fullSlotCount = fType.slotCount();
1181 SkASSERT(fullSlotCount <= 4);
1182
1183 // First, see which components are used.
1184 // The assignment swizzle must not reuse components.
1185 bool used[4] = {};
1186 for (int8_t component : fComponents) {
1187 SkASSERT(!used[component]);
1188 used[component] = true;
1189 }
1190
1191 // Any untouched components will need to be fetched from the original value.
1192 for (int index = 0; index < fullSlotCount; ++index) {
1193 if (!used[index]) {
1194 fUntouchedComponents.push_back(index);
1195 }
1196 }
1197
1198 // The reintegration swizzle needs to move the components back into their proper slots.
1199 fReintegrationSwizzle.resize(fullSlotCount);
1200 int reintegrateIndex = 0;
1201
1202 // This refills the untouched slots with the original values.
1203 auto refillUntouchedSlots = [&] {
1204 for (int index = 0; index < fullSlotCount; ++index) {
1205 if (!used[index]) {
1206 fReintegrationSwizzle[index] = reintegrateIndex++;
1207 }
1208 }
1209 };
1210
1211 // This places the new-value components into the proper slots.
1212 auto insertNewValuesIntoSlots = [&] {
1213 for (int index = 0; index < fComponents.size(); ++index) {
1214 fReintegrationSwizzle[fComponents[index]] = reintegrateIndex++;
1215 }
1216 };
1217
1218 // When reintegrating the untouched and new values, if the `x` slot is overwritten, we
1219 // reintegrate the new value first. Otherwise, we reintegrate the original value first.
1220 // This increases our odds of getting an identity swizzle for the reintegration.
1221 if (used[0]) {
1222 fReintegrateNewValueFirst = true;
1223 insertNewValuesIntoSlots();
1224 refillUntouchedSlots();
1225 } else {
1226 fReintegrateNewValueFirst = false;
1227 refillUntouchedSlots();
1228 insertNewValuesIntoSlots();
1229 }
1230 }
1231
load()1232 std::string load() override {
1233 return fName + "." + Swizzle::MaskString(fComponents);
1234 }
1235
store(const std::string & value)1236 std::string store(const std::string& value) override {
1237 // `variable = `
1238 std::string result = fName;
1239 result += " = ";
1240
1241 if (fUntouchedComponents.empty()) {
1242 // `(new_value);`
1243 result += '(';
1244 result += value;
1245 result += ")";
1246 } else if (fReintegrateNewValueFirst) {
1247 // `vec4<f32>((new_value), `
1248 result += to_wgsl_type(fContext, fType);
1249 result += "((";
1250 result += value;
1251 result += "), ";
1252
1253 // `variable.yz)`
1254 result += fName;
1255 result += '.';
1256 result += Swizzle::MaskString(fUntouchedComponents);
1257 result += ')';
1258 } else {
1259 // `vec4<f32>(variable.yz`
1260 result += to_wgsl_type(fContext, fType);
1261 result += '(';
1262 result += fName;
1263 result += '.';
1264 result += Swizzle::MaskString(fUntouchedComponents);
1265
1266 // `, (new_value))`
1267 result += ", (";
1268 result += value;
1269 result += "))";
1270 }
1271
1272 if (!Swizzle::IsIdentity(fReintegrationSwizzle)) {
1273 // `.wzyx`
1274 result += '.';
1275 result += Swizzle::MaskString(fReintegrationSwizzle);
1276 }
1277
1278 return result + ';';
1279 }
1280
1281 private:
1282 const Context& fContext;
1283 std::string fName;
1284 const Type& fType;
1285 ComponentArray fComponents;
1286 ComponentArray fUntouchedComponents;
1287 ComponentArray fReintegrationSwizzle;
1288 bool fReintegrateNewValueFirst = false;
1289 };
1290
generateCode()1291 bool WGSLCodeGenerator::generateCode() {
1292 // The resources of a WGSL program are structured in the following way:
1293 // - Stage attribute inputs and outputs are bundled inside synthetic structs called
1294 // VSIn/VSOut/FSIn/FSOut/CSIn.
1295 // - All uniform and storage type resources are declared in global scope.
1296 this->preprocessProgram();
1297
1298 {
1299 AutoOutputStream outputToHeader(this, &fHeader, &fIndentation);
1300 this->writeEnables();
1301 this->writeStageInputStruct();
1302 this->writeStageOutputStruct();
1303 this->writeUniformsAndBuffers();
1304 this->writeNonBlockUniformsForTests();
1305 }
1306 StringStream body;
1307 {
1308 // Emit the program body.
1309 AutoOutputStream outputToBody(this, &body, &fIndentation);
1310 const FunctionDefinition* mainFunc = nullptr;
1311 for (const ProgramElement* e : fProgram.elements()) {
1312 this->writeProgramElement(*e);
1313
1314 if (e->is<FunctionDefinition>()) {
1315 const FunctionDefinition& func = e->as<FunctionDefinition>();
1316 if (func.declaration().isMain()) {
1317 mainFunc = &func;
1318 }
1319 }
1320 }
1321
1322 // At the bottom of the program body, emit the entrypoint function.
1323 // The entrypoint relies on state that has been collected while we emitted the rest of the
1324 // program, so it's important to do it last to make sure we don't miss anything.
1325 if (mainFunc) {
1326 this->writeEntryPoint(*mainFunc);
1327 }
1328 }
1329
1330 write_stringstream(fHeader, *fOut);
1331 write_stringstream(body, *fOut);
1332
1333 this->writeUniformPolyfills();
1334
1335 return fContext.fErrors->errorCount() == 0;
1336 }
1337
writeUniformPolyfills()1338 void WGSLCodeGenerator::writeUniformPolyfills() {
1339 // If we didn't encounter any uniforms that need polyfilling, there is nothing to do.
1340 if (fFieldPolyfillMap.empty()) {
1341 return;
1342 }
1343
1344 // We store the list of polyfilled fields as pointers in a hash-map, so the order can be
1345 // inconsistent across runs. For determinism, we sort the polyfilled objects by name here.
1346 TArray<const FieldPolyfillMap::Pair*> orderedFields;
1347 orderedFields.reserve_exact(fFieldPolyfillMap.count());
1348
1349 fFieldPolyfillMap.foreach([&](const FieldPolyfillMap::Pair& pair) {
1350 orderedFields.push_back(&pair);
1351 });
1352
1353 std::sort(orderedFields.begin(),
1354 orderedFields.end(),
1355 [](const FieldPolyfillMap::Pair* a, const FieldPolyfillMap::Pair* b) {
1356 return a->second.fReplacementName < b->second.fReplacementName;
1357 });
1358
1359 THashSet<const Type*> writtenArrayElementPolyfill;
1360 bool writtenUniformMatrixPolyfill[5][5] = {}; // m[column][row] for each matrix type
1361 bool writtenUniformRowPolyfill[5] = {}; // for each matrix row-size
1362 bool anyFieldAccessed = false;
1363 for (const FieldPolyfillMap::Pair* pair : orderedFields) {
1364 const auto& [field, info] = *pair;
1365 const Type* fieldType = field->fType;
1366 const Layout* fieldLayout = &field->fLayout;
1367
1368 if (info.fIsArray) {
1369 fieldType = &fieldType->componentType();
1370 if (!writtenArrayElementPolyfill.contains(fieldType)) {
1371 writtenArrayElementPolyfill.add(fieldType);
1372 this->write("struct _skArrayElement_");
1373 this->write(fieldType->abbreviatedName());
1374 this->writeLine(" {");
1375
1376 if (info.fIsMatrix) {
1377 // Create a struct representing the array containing std140-padded matrices.
1378 this->write(" e : _skMatrix");
1379 this->write(std::to_string(fieldType->columns()));
1380 this->writeLine(std::to_string(fieldType->rows()));
1381 } else {
1382 // Create a struct representing the array with extra padding between elements.
1383 this->write(" @size(16) e : ");
1384 this->writeLine(to_wgsl_type(fContext, *fieldType, fieldLayout));
1385 }
1386 this->writeLine("};");
1387 }
1388 }
1389
1390 if (info.fIsMatrix) {
1391 // Create structs representing the matrix as an array of vectors, whether or not the
1392 // matrix is ever accessed by the SkSL. (The struct itself is mentioned in the list of
1393 // uniforms.)
1394 int c = fieldType->columns();
1395 int r = fieldType->rows();
1396 if (!writtenUniformRowPolyfill[r]) {
1397 writtenUniformRowPolyfill[r] = true;
1398
1399 this->write("struct _skRow");
1400 this->write(std::to_string(r));
1401 this->writeLine(" {");
1402 this->write(" @size(16) r : vec");
1403 this->write(std::to_string(r));
1404 this->write("<");
1405 this->write(to_wgsl_type(fContext, fieldType->componentType(), fieldLayout));
1406 this->writeLine(">");
1407 this->writeLine("};");
1408 }
1409
1410 if (!writtenUniformMatrixPolyfill[c][r]) {
1411 writtenUniformMatrixPolyfill[c][r] = true;
1412
1413 this->write("struct _skMatrix");
1414 this->write(std::to_string(c));
1415 this->write(std::to_string(r));
1416 this->writeLine(" {");
1417 this->write(" c : array<_skRow");
1418 this->write(std::to_string(r));
1419 this->write(", ");
1420 this->write(std::to_string(c));
1421 this->writeLine(">");
1422 this->writeLine("};");
1423 }
1424 }
1425
1426 // We create a polyfill variable only if the uniform was actually accessed.
1427 if (!info.fWasAccessed) {
1428 continue;
1429 }
1430 anyFieldAccessed = true;
1431 this->write("var<private> ");
1432 this->write(info.fReplacementName);
1433 this->write(": ");
1434
1435 const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
1436 if (interfaceBlockType.isArray()) {
1437 this->write("array<");
1438 this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
1439 this->write(", ");
1440 this->write(std::to_string(interfaceBlockType.columns()));
1441 this->write(">");
1442 } else {
1443 this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
1444 }
1445 this->writeLine(";");
1446 }
1447
1448 // If no fields were actually accessed, _skInitializePolyfilledUniforms will not be called and
1449 // we can avoid emitting an empty, dead function.
1450 if (!anyFieldAccessed) {
1451 return;
1452 }
1453
1454 this->writeLine("fn _skInitializePolyfilledUniforms() {");
1455 ++fIndentation;
1456
1457 for (const FieldPolyfillMap::Pair* pair : orderedFields) {
1458 // Only initialize a polyfill global if the uniform was actually accessed.
1459 const auto& [field, info] = *pair;
1460 if (!info.fWasAccessed) {
1461 continue;
1462 }
1463
1464 // Synthesize the name of this uniform variable
1465 std::string_view instanceName = info.fInterfaceBlock->instanceName();
1466 const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
1467 if (instanceName.empty()) {
1468 instanceName = fInterfaceBlockNameMap[&interfaceBlockType.componentType()];
1469 }
1470
1471 // Initialize the global variable associated with this uniform.
1472 // If the interface block is arrayed, the associated global will be arrayed as well.
1473 int numIBElements = interfaceBlockType.isArray() ? interfaceBlockType.columns() : 1;
1474 for (int ibIdx = 0; ibIdx < numIBElements; ++ibIdx) {
1475 this->write(info.fReplacementName);
1476 if (interfaceBlockType.isArray()) {
1477 this->write("[");
1478 this->write(std::to_string(ibIdx));
1479 this->write("]");
1480 }
1481 this->write(" = ");
1482
1483 const Type* fieldType = field->fType;
1484 const Layout* fieldLayout = &field->fLayout;
1485
1486 int numArrayElements;
1487 if (info.fIsArray) {
1488 this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
1489 this->write("(");
1490 numArrayElements = fieldType->columns();
1491 fieldType = &fieldType->componentType();
1492 } else {
1493 numArrayElements = 1;
1494 }
1495
1496 auto arraySeparator = String::Separator();
1497 for (int arrayIdx = 0; arrayIdx < numArrayElements; arrayIdx++) {
1498 this->write(arraySeparator());
1499
1500 std::string fieldName{instanceName};
1501 if (interfaceBlockType.isArray()) {
1502 fieldName += '[';
1503 fieldName += std::to_string(ibIdx);
1504 fieldName += ']';
1505 }
1506 fieldName += '.';
1507 fieldName += this->assembleName(field->fName);
1508
1509 if (info.fIsArray) {
1510 fieldName += '[';
1511 fieldName += std::to_string(arrayIdx);
1512 fieldName += "].e";
1513 }
1514
1515 if (info.fIsMatrix) {
1516 this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
1517 this->write("(");
1518 int numColumns = fieldType->columns();
1519 auto matrixSeparator = String::Separator();
1520 for (int column = 0; column < numColumns; column++) {
1521 this->write(matrixSeparator());
1522 this->write(fieldName);
1523 this->write(".c[");
1524 this->write(std::to_string(column));
1525 this->write("].r");
1526 }
1527 this->write(")");
1528 } else {
1529 this->write(fieldName);
1530 }
1531 }
1532
1533 if (info.fIsArray) {
1534 this->write(")");
1535 }
1536
1537 this->writeLine(";");
1538 }
1539 }
1540
1541 --fIndentation;
1542 this->writeLine("}");
1543 }
1544
1545
preprocessProgram()1546 void WGSLCodeGenerator::preprocessProgram() {
1547 fRequirements = resolve_program_requirements(&fProgram);
1548 collect_pipeline_io_vars(&fProgram, &fPipelineInputs, ModifierFlag::kIn);
1549 collect_pipeline_io_vars(&fProgram, &fPipelineOutputs, ModifierFlag::kOut);
1550 }
1551
write(std::string_view s)1552 void WGSLCodeGenerator::write(std::string_view s) {
1553 if (s.empty()) {
1554 return;
1555 }
1556 if (fAtLineStart && fPrettyPrint == PrettyPrint::kYes) {
1557 for (int i = 0; i < fIndentation; i++) {
1558 fOut->writeText(" ");
1559 }
1560 }
1561 fOut->writeText(std::string(s).c_str());
1562 fAtLineStart = false;
1563 }
1564
writeLine(std::string_view s)1565 void WGSLCodeGenerator::writeLine(std::string_view s) {
1566 this->write(s);
1567 fOut->writeText("\n");
1568 fAtLineStart = true;
1569 }
1570
finishLine()1571 void WGSLCodeGenerator::finishLine() {
1572 if (!fAtLineStart) {
1573 this->writeLine();
1574 }
1575 }
1576
assembleName(std::string_view name)1577 std::string WGSLCodeGenerator::assembleName(std::string_view name) {
1578 if (name.empty()) {
1579 // WGSL doesn't allow anonymous function parameters.
1580 return "_skAnonymous" + std::to_string(fScratchCount++);
1581 }
1582 // Add `R_` before reserved names to avoid any potential reserved-word conflict.
1583 return (skstd::starts_with(name, "_sk") ||
1584 skstd::starts_with(name, "R_") ||
1585 is_reserved_word(name))
1586 ? std::string("R_") + std::string(name)
1587 : std::string(name);
1588 }
1589
writeVariableDecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1590 void WGSLCodeGenerator::writeVariableDecl(const Layout& layout,
1591 const Type& type,
1592 std::string_view name,
1593 Delimiter delimiter) {
1594 this->write(this->assembleName(name));
1595 this->write(": " + to_wgsl_type(fContext, type, &layout));
1596 this->writeLine(delimiter_to_str(delimiter));
1597 }
1598
writePipelineIODeclaration(const Layout & layout,const Type & type,ModifierFlags modifiers,std::string_view name,Delimiter delimiter)1599 void WGSLCodeGenerator::writePipelineIODeclaration(const Layout& layout,
1600 const Type& type,
1601 ModifierFlags modifiers,
1602 std::string_view name,
1603 Delimiter delimiter) {
1604 // In WGSL, an entry-point IO parameter is "one of either a built-in value or assigned a
1605 // location". However, some SkSL declarations, specifically sk_FragColor, can contain both a
1606 // location and a builtin modifier. In addition, WGSL doesn't have a built-in equivalent for
1607 // sk_FragColor as it relies on the user-defined location for a render target.
1608 //
1609 // Instead of special-casing sk_FragColor, we just give higher precedence to a location modifier
1610 // if a declaration happens to both have a location and it's a built-in.
1611 //
1612 // Also see:
1613 // https://www.w3.org/TR/WGSL/#input-output-locations
1614 // https://www.w3.org/TR/WGSL/#attribute-location
1615 // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
1616 if (layout.fLocation >= 0) {
1617 this->writeUserDefinedIODecl(layout, type, modifiers, name, delimiter);
1618 return;
1619 }
1620 if (layout.fBuiltin >= 0) {
1621 if (layout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1622 // WebGPU does not support the point-size builtin, but we silently replace it with a
1623 // global variable when it is used, instead of reporting an error.
1624 return;
1625 }
1626 auto builtin = builtin_from_sksl_name(layout.fBuiltin);
1627 if (builtin.has_value()) {
1628 // Builtin IO parameters should only have in/out modifiers, which are then implicit in
1629 // the generated WGSL, hence why writeBuiltinIODecl does not need them passed in.
1630 SkASSERT(!(modifiers & ~(ModifierFlag::kIn | ModifierFlag::kOut)));
1631 this->writeBuiltinIODecl(type, name, *builtin, delimiter);
1632 return;
1633 }
1634 }
1635 fContext.fErrors->error(Position(), "declaration '" + std::string(name) + "' is not supported");
1636 }
1637
writeUserDefinedIODecl(const Layout & layout,const Type & type,ModifierFlags flags,std::string_view name,Delimiter delimiter)1638 void WGSLCodeGenerator::writeUserDefinedIODecl(const Layout& layout,
1639 const Type& type,
1640 ModifierFlags flags,
1641 std::string_view name,
1642 Delimiter delimiter) {
1643 this->write("@location(" + std::to_string(layout.fLocation) + ") ");
1644
1645 // @blend_src is only allowed when doing dual-source blending, and only on color attachment 0.
1646 if (layout.fLocation == 0 && layout.fIndex >= 0 && fProgram.fInterface.fOutputSecondaryColor) {
1647 this->write("@blend_src(" + std::to_string(layout.fIndex) + ") ");
1648 }
1649
1650 // "User-defined IO of scalar or vector integer type must always be specified as
1651 // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
1652 if (flags.isFlat() || type.isInteger() ||
1653 (type.isVector() && type.componentType().isInteger())) {
1654 // We can use 'either' to hint to WebGPU that we don't care about the provoking vertex and
1655 // avoid any expensive shader or data rewriting to ensure 'first'. Skia has a long-standing
1656 // policy to only use flat shading when it's constant for a primitive so the vertex doesn't
1657 // matter. See https://www.w3.org/TR/WGSL/#interpolation-sampling-either
1658 this->write("@interpolate(flat, either) ");
1659 } else if (flags & ModifierFlag::kNoPerspective) {
1660 this->write("@interpolate(linear) ");
1661 }
1662
1663 this->writeVariableDecl(layout, type, name, delimiter);
1664 }
1665
writeBuiltinIODecl(const Type & type,std::string_view name,Builtin builtin,Delimiter delimiter)1666 void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
1667 std::string_view name,
1668 Builtin builtin,
1669 Delimiter delimiter) {
1670 this->write(wgsl_builtin_name(builtin));
1671 this->write(" ");
1672 this->write(this->assembleName(name));
1673 this->write(": ");
1674 this->write(wgsl_builtin_type(builtin));
1675 this->writeLine(delimiter_to_str(delimiter));
1676 }
1677
writeFunction(const FunctionDefinition & f)1678 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
1679 const FunctionDeclaration& decl = f.declaration();
1680 fHasUnconditionalReturn = false;
1681 fConditionalScopeDepth = 0;
1682
1683 SkASSERT(!fAtFunctionScope);
1684 fAtFunctionScope = true;
1685
1686 // WGSL parameters are immutable and are considered as taking no storage, but SkSL parameters
1687 // are real variables. To work around this, we make var-based copies of parameters. It's
1688 // wasteful to make a copy of every single parameter--even if the compiler can eventually
1689 // optimize them all away, that takes time and generates bloated code. So, we only make
1690 // parameter copies if the variable is actually written-to.
1691 STArray<32, bool> paramNeedsDedicatedStorage;
1692 paramNeedsDedicatedStorage.push_back_n(decl.parameters().size(), true);
1693
1694 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1695 const Variable& param = *decl.parameters()[index];
1696 if (param.type().isOpaque() || param.name().empty()) {
1697 // Opaque-typed or anonymous parameters don't need dedicated storage.
1698 paramNeedsDedicatedStorage[index] = false;
1699 continue;
1700 }
1701
1702 const ProgramUsage::VariableCounts counts = fProgram.fUsage->get(param);
1703 if ((param.modifierFlags() & ModifierFlag::kOut) || counts.fWrite == 0) {
1704 // Variables which are never written-to don't need dedicated storage.
1705 // Out-parameters are passed as pointers; the pointer itself is never modified, so
1706 // it doesn't need dedicated storage.
1707 paramNeedsDedicatedStorage[index] = false;
1708 }
1709 }
1710
1711 this->writeFunctionDeclaration(decl, paramNeedsDedicatedStorage);
1712 this->writeLine(" {");
1713 ++fIndentation;
1714
1715 // The parameters were given generic names like `_skParam1`, because WGSL parameters don't have
1716 // storage and are immutable. If mutability is required, we create variables here; otherwise, we
1717 // create properly-named `let` aliases.
1718 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1719 if (paramNeedsDedicatedStorage[index]) {
1720 const Variable& param = *decl.parameters()[index];
1721 this->write("var ");
1722 this->write(this->assembleName(param.mangledName()));
1723 this->write(" = _skParam");
1724 this->write(std::to_string(index));
1725 this->writeLine(";");
1726 }
1727 }
1728
1729 this->writeBlock(f.body()->as<Block>());
1730
1731 // If fConditionalScopeDepth isn't zero, we have an unbalanced +1 or -1 when updating the depth.
1732 SkASSERT(fConditionalScopeDepth == 0);
1733 if (!fHasUnconditionalReturn && !f.declaration().returnType().isVoid()) {
1734 this->write("return ");
1735 this->write(to_wgsl_type(fContext, f.declaration().returnType()));
1736 this->writeLine("();");
1737 }
1738
1739 --fIndentation;
1740 this->writeLine("}");
1741
1742 SkASSERT(fAtFunctionScope);
1743 fAtFunctionScope = false;
1744 }
1745
writeFunctionDeclaration(const FunctionDeclaration & decl,SkSpan<const bool> paramNeedsDedicatedStorage)1746 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& decl,
1747 SkSpan<const bool> paramNeedsDedicatedStorage) {
1748 this->write("fn ");
1749 if (decl.isMain()) {
1750 this->write("_skslMain(");
1751 } else {
1752 this->write(this->assembleName(decl.mangledName()));
1753 this->write("(");
1754 }
1755 auto separator = SkSL::String::Separator();
1756 if (this->writeFunctionDependencyParams(decl)) {
1757 separator(); // update the separator as parameters have been written
1758 }
1759 for (size_t index = 0; index < decl.parameters().size(); ++index) {
1760 this->write(separator());
1761
1762 const Variable& param = *decl.parameters()[index];
1763 if (param.type().isOpaque()) {
1764 SkASSERT(!paramNeedsDedicatedStorage[index]);
1765 if (param.type().isSampler()) {
1766 // Create parameters for both the texture and associated sampler.
1767 this->write(param.name());
1768 this->write(kTextureSuffix);
1769 this->write(": texture_2d<f32>, ");
1770 this->write(param.name());
1771 this->write(kSamplerSuffix);
1772 this->write(": sampler");
1773 } else {
1774 // Create a parameter for the opaque object.
1775 this->write(param.name());
1776 this->write(": ");
1777 this->write(to_wgsl_type(fContext, param.type(), ¶m.layout()));
1778 }
1779 } else {
1780 if (paramNeedsDedicatedStorage[index] || param.name().empty()) {
1781 // Create an unnamed parameter. If the parameter needs dedicated storage, it will
1782 // later be assigned a `var` in the function body. (If it's anonymous, a var isn't
1783 // needed.)
1784 this->write("_skParam");
1785 this->write(std::to_string(index));
1786 } else {
1787 // Use the name directly from the SkSL program.
1788 this->write(this->assembleName(param.name()));
1789 }
1790 this->write(": ");
1791 if (param.type().isUnsizedArray()) {
1792 // Creates a storage address space pointer for unsized array parameters.
1793 // The buffer the array resides in must be marked `readonly` to have the array
1794 // be used in function parameters, since access modes in wgsl must exactly match.
1795 this->write("ptr<storage, ");
1796 this->write(to_wgsl_type(fContext, param.type(), ¶m.layout()));
1797 this->write(", read>");
1798 } else if (param.modifierFlags() & ModifierFlag::kOut) {
1799 // Declare an "out" function parameter as a pointer.
1800 this->write(to_ptr_type(fContext, param.type(), ¶m.layout()));
1801 } else {
1802 this->write(to_wgsl_type(fContext, param.type(), ¶m.layout()));
1803 }
1804 }
1805 }
1806 this->write(")");
1807 if (!decl.returnType().isVoid()) {
1808 this->write(" -> ");
1809 this->write(to_wgsl_type(fContext, decl.returnType()));
1810 }
1811 }
1812
// Emits the WGSL stage entry point, `fn main`, for the program.
//
// Normally this is a trampoline. It takes the synthesized stage-input struct as a parameter
// (when needsStageInputStruct()), calls the user-defined `_skslMain` with the pipeline
// inputs/outputs it depends on, and returns the synthesized stage-output struct (when
// needsStageOutputStruct()).
//
// When fGenSyntheticCode is kYes (skslc testing), two things change:
//   - runtime shaders get a minimal `@fragment` wrapper instead;
//   - fragment programs whose main returns `half4` store that result into
//     `_stageOut.sk_FragColor` and receive a dummy coords argument.
//
// Errors are reported through fContext.fErrors for:
//   - unsupported program kinds;
//   - a synthetic-mode coords parameter that is not `float2`.
void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
    SkASSERT(main.declaration().isMain());
    const ProgramKind programKind = fProgram.fConfig->fKind;

    if (fGenSyntheticCode == IncludeSyntheticCode::kYes &&
        ProgramConfig::IsRuntimeShader(programKind)) {
        // Synthesize a basic entrypoint which just calls straight through to main.
        // This is only used by skslc and just needs to pass the WGSL validator; Skia won't ever
        // emit functions like this.
        this->writeLine("@fragment fn main(@location(0) _coords: vec2<f32>) -> "
                        "@location(0) vec4<f32> {");
        ++fIndentation;
        this->writeLine("return _skslMain(_coords);");
        --fIndentation;
        this->writeLine("}");
        return;
    }

    // The input and output parameters for a vertex/fragment stage entry point function have the
    // FSIn/FSOut/VSIn/VSOut/CSIn struct types that have been synthesized in generateCode(). An
    // entrypoint always has a predictable signature and acts as a trampoline to the user-defined
    // main function.
    // Emit the stage attribute; compute stages also carry their local workgroup size.
    if (ProgramConfig::IsVertex(programKind)) {
        this->write("@vertex");
    } else if (ProgramConfig::IsFragment(programKind)) {
        this->write("@fragment");
    } else if (ProgramConfig::IsCompute(programKind)) {
        this->write("@compute @workgroup_size(");
        this->write(std::to_string(fLocalSizeX));
        this->write(", ");
        this->write(std::to_string(fLocalSizeY));
        this->write(", ");
        this->write(std::to_string(fLocalSizeZ));
        this->write(")");
    } else {
        fContext.fErrors->error(Position(), "program kind not supported");
        return;
    }

    this->write(" fn main(");
    // The stage input struct is a parameter passed to main().
    if (this->needsStageInputStruct()) {
        this->write("_stageIn: ");
        this->write(pipeline_struct_prefix(programKind));
        this->write("In");
    }
    // The stage output struct is returned from main().
    if (this->needsStageOutputStruct()) {
        this->write(") -> ");
        this->write(pipeline_struct_prefix(programKind));
        this->writeLine("Out {");
    } else {
        this->writeLine(") {");
    }
    // Initialize polyfilled matrix uniforms if any were used. The initializer call is emitted at
    // most once, as soon as any accessed polyfilled field is found.
    fIndentation++;
    for (const auto& [field, info] : fFieldPolyfillMap) {
        if (info.fWasAccessed) {
            this->writeLine("_skInitializePolyfilledUniforms();");
            break;
        }
    }
    // Declare the stage output struct.
    if (this->needsStageOutputStruct()) {
        this->write("var _stageOut: ");
        this->write(pipeline_struct_prefix(programKind));
        this->writeLine("Out;");
    }

    // We are compiling a Runtime Effect as a fragment shader, for testing purposes. We assign the
    // result from _skslMain into sk_FragColor if the user-defined main returns a color. This
    // doesn't actually matter, but it is more indicative of what a real program would do.
    // `addImplicitFragColorWrite` from Transform::FindAndDeclareBuiltinVariables has already
    // injected sk_FragColor into our stage outputs even if it wasn't explicitly referenced.
    if (fGenSyntheticCode == IncludeSyntheticCode::kYes && ProgramConfig::IsFragment(programKind)) {
        if (main.declaration().returnType().matches(*fContext.fTypes.fHalf4)) {
            this->write("_stageOut.sk_FragColor = ");
        }
    }

    // Generate a function call to the user-defined main. Only the pipeline structs that main
    // was recorded as depending on (in fRequirements.fDependencies) are forwarded.
    this->write("_skslMain(");
    auto separator = SkSL::String::Separator();
    WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&main.declaration());
    if (deps) {
        if (*deps & WGSLFunctionDependency::kPipelineInputs) {
            this->write(separator());
            this->write("_stageIn");
        }
        if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
            // Outputs are passed by pointer so _skslMain can write into them.
            this->write(separator());
            this->write("&_stageOut");
        }
    }

    if (fGenSyntheticCode == IncludeSyntheticCode::kYes) {
        if (const Variable* v = main.declaration().getMainCoordsParameter()) {
            // We are compiling a Runtime Effect as a fragment shader, for testing purposes.
            // We need to synthesize a coordinates parameter, but the coordinates don't matter.
            SkASSERT(ProgramConfig::IsFragment(programKind));
            const Type& type = v->type();
            if (!type.matches(*fContext.fTypes.fFloat2)) {
                // NOTE(review): this early return leaves fIndentation incremented and the
                // `_skslMain(` call unterminated; presumably harmless because the reported error
                // fails compilation — confirm.
                fContext.fErrors->error(
                        main.fPosition,
                        "main function has unsupported parameter: " + type.description());
                return;
            }
            this->write(separator());
            this->write("/*fragcoord*/ vec2<f32>()");
        }
    }

    this->writeLine(");");

    if (this->needsStageOutputStruct()) {
        // Return the stage output struct.
        this->writeLine("return _stageOut;");
    }

    fIndentation--;
    this->writeLine("}");
}
1935
writeStatement(const Statement & s)1936 void WGSLCodeGenerator::writeStatement(const Statement& s) {
1937 switch (s.kind()) {
1938 case Statement::Kind::kBlock:
1939 this->writeBlock(s.as<Block>());
1940 break;
1941 case Statement::Kind::kBreak:
1942 this->writeLine("break;");
1943 break;
1944 case Statement::Kind::kContinue:
1945 this->writeLine("continue;");
1946 break;
1947 case Statement::Kind::kDiscard:
1948 this->writeLine("discard;");
1949 break;
1950 case Statement::Kind::kDo:
1951 this->writeDoStatement(s.as<DoStatement>());
1952 break;
1953 case Statement::Kind::kExpression:
1954 this->writeExpressionStatement(*s.as<ExpressionStatement>().expression());
1955 break;
1956 case Statement::Kind::kFor:
1957 this->writeForStatement(s.as<ForStatement>());
1958 break;
1959 case Statement::Kind::kIf:
1960 this->writeIfStatement(s.as<IfStatement>());
1961 break;
1962 case Statement::Kind::kNop:
1963 this->writeLine(";");
1964 break;
1965 case Statement::Kind::kReturn:
1966 this->writeReturnStatement(s.as<ReturnStatement>());
1967 break;
1968 case Statement::Kind::kSwitch:
1969 this->writeSwitchStatement(s.as<SwitchStatement>());
1970 break;
1971 case Statement::Kind::kSwitchCase:
1972 SkDEBUGFAIL("switch-case statements should only be present inside a switch");
1973 break;
1974 case Statement::Kind::kVarDeclaration:
1975 this->writeVarDeclaration(s.as<VarDeclaration>());
1976 break;
1977 }
1978 }
1979
writeStatements(const StatementArray & statements)1980 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
1981 for (const auto& s : statements) {
1982 if (!s->isEmpty()) {
1983 this->writeStatement(*s);
1984 this->finishLine();
1985 }
1986 }
1987 }
1988
writeBlock(const Block & b)1989 void WGSLCodeGenerator::writeBlock(const Block& b) {
1990 // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
1991 // something here to make the code valid).
1992 bool isScope = b.isScope() || b.isEmpty();
1993 if (isScope) {
1994 this->writeLine("{");
1995 fIndentation++;
1996 }
1997 this->writeStatements(b.children());
1998 if (isScope) {
1999 fIndentation--;
2000 this->writeLine("}");
2001 }
2002 }
2003
writeExpressionStatement(const Expression & expr)2004 void WGSLCodeGenerator::writeExpressionStatement(const Expression& expr) {
2005 // Any expression-related side effects must be emitted as separate statements when
2006 // `assembleExpression` is called.
2007 // The final result of the expression will be a variable, let-reference, or an expression with
2008 // no side effects (`foo + bar`). Discarding this result is safe, as the program never uses it.
2009 (void)this->assembleExpression(expr, Precedence::kStatement);
2010 }
2011
writeDoStatement(const DoStatement & s)2012 void WGSLCodeGenerator::writeDoStatement(const DoStatement& s) {
2013 // Generate a loop structure like this:
2014 // loop {
2015 // body-statement;
2016 // continuing {
2017 // break if inverted-test-expression;
2018 // }
2019 // }
2020
2021 ++fConditionalScopeDepth;
2022
2023 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2024 fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2025
2026 this->writeLine("loop {");
2027 fIndentation++;
2028 this->writeStatement(*s.statement());
2029 this->finishLine();
2030
2031 this->writeLine("continuing {");
2032 fIndentation++;
2033 std::string breakIfExpr = this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2034 this->write("break if ");
2035 this->write(breakIfExpr);
2036 this->writeLine(";");
2037 fIndentation--;
2038 this->writeLine("}");
2039 fIndentation--;
2040 this->writeLine("}");
2041
2042 --fConditionalScopeDepth;
2043 }
2044
writeForStatement(const ForStatement & s)2045 void WGSLCodeGenerator::writeForStatement(const ForStatement& s) {
2046 // Generate a loop structure wrapped in an extra scope:
2047 // {
2048 // initializer-statement;
2049 // loop;
2050 // }
2051 // The outer scope is necessary to prevent the initializer-variable from leaking out into the
2052 // rest of the code. In practice, the generated code actually tends to be scoped even more
2053 // deeply, as the body-statement almost always contributes an extra block.
2054
2055 ++fConditionalScopeDepth;
2056
2057 if (s.initializer()) {
2058 this->writeLine("{");
2059 fIndentation++;
2060 this->writeStatement(*s.initializer());
2061 this->writeLine();
2062 }
2063
2064 this->writeLine("loop {");
2065 fIndentation++;
2066
2067 if (s.unrollInfo()) {
2068 if (s.unrollInfo()->fCount <= 0) {
2069 // Loops which are known to never execute don't need to be emitted at all.
2070 // (However, the front end should have already replaced this loop with a Nop.)
2071 } else {
2072 // Loops which are known to execute at least once can use this form:
2073 //
2074 // loop {
2075 // body-statement;
2076 // continuing {
2077 // next-expression;
2078 // break if inverted-test-expression;
2079 // }
2080 // }
2081
2082 this->writeStatement(*s.statement());
2083 this->finishLine();
2084 this->writeLine("continuing {");
2085 ++fIndentation;
2086
2087 if (s.next()) {
2088 this->writeExpressionStatement(*s.next());
2089 this->finishLine();
2090 }
2091
2092 if (s.test()) {
2093 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2094 fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2095
2096 std::string breakIfExpr =
2097 this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2098 this->write("break if ");
2099 this->write(breakIfExpr);
2100 this->writeLine(";");
2101 }
2102
2103 --fIndentation;
2104 this->writeLine("}");
2105 }
2106 } else {
2107 // Loops without a known execution count are emitted in this form:
2108 //
2109 // loop {
2110 // if test-expression {
2111 // body-statement;
2112 // } else {
2113 // break;
2114 // }
2115 // continuing {
2116 // next-expression;
2117 // }
2118 // }
2119
2120 if (s.test()) {
2121 std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2122 this->write("if ");
2123 this->write(testExpr);
2124 this->writeLine(" {");
2125
2126 fIndentation++;
2127 this->writeStatement(*s.statement());
2128 this->finishLine();
2129 fIndentation--;
2130
2131 this->writeLine("} else {");
2132
2133 fIndentation++;
2134 this->writeLine("break;");
2135 fIndentation--;
2136
2137 this->writeLine("}");
2138 }
2139 else {
2140 this->writeStatement(*s.statement());
2141 this->finishLine();
2142 }
2143
2144 if (s.next()) {
2145 this->writeLine("continuing {");
2146 fIndentation++;
2147 this->writeExpressionStatement(*s.next());
2148 this->finishLine();
2149 fIndentation--;
2150 this->writeLine("}");
2151 }
2152 }
2153
2154 // This matches an open-brace at the top of the loop.
2155 fIndentation--;
2156 this->writeLine("}");
2157
2158 if (s.initializer()) {
2159 // This matches an open-brace before the initializer-statement.
2160 fIndentation--;
2161 this->writeLine("}");
2162 }
2163
2164 --fConditionalScopeDepth;
2165 }
2166
writeIfStatement(const IfStatement & s)2167 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
2168 ++fConditionalScopeDepth;
2169
2170 std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2171 this->write("if ");
2172 this->write(testExpr);
2173 this->writeLine(" {");
2174 fIndentation++;
2175 this->writeStatement(*s.ifTrue());
2176 this->finishLine();
2177 fIndentation--;
2178 if (s.ifFalse()) {
2179 this->writeLine("} else {");
2180 fIndentation++;
2181 this->writeStatement(*s.ifFalse());
2182 this->finishLine();
2183 fIndentation--;
2184 }
2185 this->writeLine("}");
2186
2187 --fConditionalScopeDepth;
2188 }
2189
writeReturnStatement(const ReturnStatement & s)2190 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
2191 fHasUnconditionalReturn |= (fConditionalScopeDepth == 0);
2192
2193 std::string expr = s.expression()
2194 ? this->assembleExpression(*s.expression(), Precedence::kExpression)
2195 : std::string();
2196 this->write("return ");
2197 this->write(expr);
2198 this->write(";");
2199 }
2200
writeSwitchCaseList(SkSpan<const SwitchCase * const> cases)2201 void WGSLCodeGenerator::writeSwitchCaseList(SkSpan<const SwitchCase* const> cases) {
2202 auto separator = SkSL::String::Separator();
2203 for (const SwitchCase* const sc : cases) {
2204 this->write(separator());
2205 if (sc->isDefault()) {
2206 this->write("default");
2207 } else {
2208 this->write(std::to_string(sc->value()));
2209 }
2210 }
2211 }
2212
writeSwitchCases(SkSpan<const SwitchCase * const> cases)2213 void WGSLCodeGenerator::writeSwitchCases(SkSpan<const SwitchCase* const> cases) {
2214 if (!cases.empty()) {
2215 // Only the last switch-case should have a non-empty statement.
2216 SkASSERT(std::all_of(cases.begin(), std::prev(cases.end()), [](const SwitchCase* sc) {
2217 return sc->statement()->isEmpty();
2218 }));
2219
2220 // Emit the cases in a comma-separated list.
2221 this->write("case ");
2222 this->writeSwitchCaseList(cases);
2223 this->writeLine(" {");
2224 ++fIndentation;
2225
2226 // Emit the switch-case body.
2227 this->writeStatement(*cases.back()->statement());
2228 this->finishLine();
2229
2230 --fIndentation;
2231 this->writeLine("}");
2232 }
2233 }
2234
// Emits a single WGSL `case` clause that matches every label in `cases`. Inside it, SkSL
// fallthrough between those labels is emulated with a chain of `if` blocks guarded by a scratch
// `bool` (created via writeScratchVar, initialized to `false`):
//
//     case A, B, C {
//         <scratch bool declaration>
//         if switchValue == A { A-body; <scratch> = true; // fallthrough }
//         if <scratch> || switchValue == B { B-body; // fallthrough }
//         C-body;
//     }
//
// `switchValue` is the WGSL text of the switch selector; it is repeated in each comparison.
// With fewer than two labels no emulation is needed, so this defers to writeSwitchCases().
void WGSLCodeGenerator::writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
                                                            std::string_view switchValue) {
    // There's no need for fallthrough handling unless we actually have multiple case blocks.
    if (cases.size() < 2) {
        this->writeSwitchCases(cases);
        return;
    }

    // Match against the entire case group.
    this->write("case ");
    this->writeSwitchCaseList(cases);
    this->writeLine(" {");
    ++fIndentation;

    std::string fallthroughVar = this->writeScratchVar(*fContext.fTypes.fBool, "false");
    const size_t secondToLastCaseIndex = cases.size() - 2;
    const size_t lastCaseIndex = cases.size() - 1;

    for (size_t index = 0; index < cases.size(); ++index) {
        const SwitchCase& sc = *cases[index];
        if (index < lastCaseIndex) {
            // The default case must come last in SkSL, and this case isn't the last one, so it
            // can't possibly be the default.
            SkASSERT(!sc.isDefault());

            // The first case can only be entered by matching its value; later cases can also
            // be entered by falling through from an earlier one.
            this->write("if ");
            if (index > 0) {
                this->write(fallthroughVar);
                this->write(" || ");
            }
            this->write(switchValue);
            this->write(" == ");
            this->write(std::to_string(sc.value()));
            this->writeLine(" {");
            fIndentation++;

            // We write the entire case-block statement here, and then set the fallthrough
            // variable to `true`. If the case-block contains a `break`, it exits the enclosing
            // WGSL `switch`. In that case the assignment never runs, and neither does any later
            // code in this case clause. We've forbidden `continue` statements inside switch
            // case-blocks entirely, so we don't need to consider their effect on control flow;
            // see the Finalizer in FunctionDefinition::Convert.
            this->writeStatement(*sc.statement());
            this->finishLine();

            if (index < secondToLastCaseIndex) {
                // Set a variable to indicate falling through to the next block. The very last
                // case-block is reached by process of elimination and doesn't need this
                // variable, so we don't actually need to set it if we are on the second-to-last
                // case block.
                this->write(fallthroughVar);
                this->write(" = true;  ");
            }
            this->writeLine("// fallthrough");

            fIndentation--;
            this->writeLine("}");
        } else {
            // This is the final case. Since it's always last, we can just dump in the code.
            // (If we didn't match any of the other values, we must have matched this one by
            // process of elimination. If we did match one of the other values, we either hit a
            // `break` statement earlier--and won't get this far--or we're falling through.)
            this->writeStatement(*sc.statement());
            this->finishLine();
        }
    }

    --fIndentation;
    this->writeLine("}");
}
2305
// Emits an SkSL switch statement as a WGSL `switch`.
//
// WGSL supports the `switch` statement in a limited capacity:
//   - A default case must always be specified.
//   - Each switch-case must be scoped inside braces.
//   - Fallthrough is not supported; a trailing break is implied at the end of each switch-case
//     block. (Explicit breaks are also allowed.)
// One minor improvement over a traditional switch is that switch-cases take a list of values to
// match, instead of a single value:
//     case 1, 2 { foo(); }
//     case 3, default { bar(); }
//
// We will use the native WGSL switch statement for any switch-cases in the SkSL which can be
// made to conform to these limitations. The remaining cases, which cannot conform, are emulated
// with if-else blocks (similar to our GLSL ES2 switch-statement emulation path). This should give
// us good performance in the common case, since most switches naturally conform.
//
// NOTE(review): unlike the if/for/do-while writers, this function does not increment
// fConditionalScopeDepth. As a result, writeReturnStatement counts a `return` inside a case
// clause as unconditional — confirm this is intended.
void WGSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
    // First, let's emit the switch itself. `valueExpr` is also handed to the fallthrough
    // emulation below, which repeats it in several comparisons. writeNontrivialScratchLet
    // presumably hoists nontrivial selectors into a `let` for that reason — TODO confirm.
    std::string valueExpr = this->writeNontrivialScratchLet(*s.value(), Precedence::kExpression);
    this->write("switch ");
    this->write(valueExpr);
    this->writeLine(" {");
    ++fIndentation;

    // Now let's go through the switch-cases, and emit the ones that don't fall through.
    // `nativeCases` collects labels that can share one native `case` clause.
    // `fallthroughCases` collects labels that must be emitted through fallthrough emulation.
    TArray<const SwitchCase*> nativeCases;
    TArray<const SwitchCase*> fallthroughCases;
    bool previousCaseFellThrough = false;
    bool foundNativeDefault = false;
    [[maybe_unused]] bool foundFallthroughDefault = false;

    const int lastSwitchCaseIdx = s.cases().size() - 1;
    for (int index = 0; index <= lastSwitchCaseIdx; ++index) {
        const SwitchCase& sc = s.cases()[index]->as<SwitchCase>();

        if (sc.statement()->isEmpty()) {
            // This is a `case X:` that immediately falls through to the next case.
            // If we aren't already falling through, we can handle this via a comma-separated list.
            if (previousCaseFellThrough) {
                fallthroughCases.push_back(&sc);
                foundFallthroughDefault |= sc.isDefault();
            } else {
                nativeCases.push_back(&sc);
                foundNativeDefault |= sc.isDefault();
            }
            continue;
        }

        if (index == lastSwitchCaseIdx || Analysis::SwitchCaseContainsUnconditionalExit(sc)) {
            // This is a `case X:` that never falls through.
            if (previousCaseFellThrough) {
                // Because the previous cases fell through, we can't use a native switch-case here.
                fallthroughCases.push_back(&sc);
                foundFallthroughDefault |= sc.isDefault();

                this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
                fallthroughCases.clear();

                // Fortunately, we're no longer falling through blocks, so we might be able to use a
                // native switch-case list again.
                previousCaseFellThrough = false;
            } else {
                // Emit a native switch-case block with a comma-separated case list.
                nativeCases.push_back(&sc);
                foundNativeDefault |= sc.isDefault();

                this->writeSwitchCases(nativeCases);
                nativeCases.clear();
            }
            continue;
        }

        // This case falls through, so it will need to be handled via emulation.
        // If we have put together a collection of "native" cases (cases that fall through with no
        // actual case-body), we will need to slide them over into the fallthrough-case list.
        fallthroughCases.push_back_n(nativeCases.size(), nativeCases.data());
        nativeCases.clear();

        fallthroughCases.push_back(&sc);
        foundFallthroughDefault |= sc.isDefault();
        previousCaseFellThrough = true;
    }

    // Finish out the remaining switch-cases. This is needed when the final SkSL case had an
    // empty body, so nothing inside the loop flushed these lists.
    this->writeSwitchCases(nativeCases);
    nativeCases.clear();

    this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
    fallthroughCases.clear();

    // WGSL requires a default case.
    if (!foundNativeDefault && !foundFallthroughDefault) {
        this->writeLine("case default {}");
    }

    --fIndentation;
    this->writeLine("}");
}
2401
writeVarDeclaration(const VarDeclaration & varDecl)2402 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2403 std::string initialValue =
2404 varDecl.value() ? this->assembleExpression(*varDecl.value(), Precedence::kAssignment)
2405 : std::string();
2406
2407 if (varDecl.var()->modifierFlags().isConst() ||
2408 (fProgram.fUsage->get(*varDecl.var()).fWrite == 1 && varDecl.value())) {
2409 // Use `const` at global scope, or if the value is a compile-time constant.
2410 SkASSERTF(varDecl.value(), "an immutable variable must specify a value");
2411 const bool useConst =
2412 !fAtFunctionScope || Analysis::IsCompileTimeConstant(*varDecl.value());
2413 this->write(useConst ? "const " : "let ");
2414 } else {
2415 this->write("var ");
2416 }
2417 this->write(this->assembleName(varDecl.var()->mangledName()));
2418 this->write(": ");
2419 this->write(to_wgsl_type(fContext, varDecl.var()->type(), &varDecl.var()->layout()));
2420
2421 if (varDecl.value()) {
2422 this->write(" = ");
2423 this->write(initialValue);
2424 }
2425
2426 this->write(";");
2427 }
2428
makeLValue(const Expression & e)2429 std::unique_ptr<WGSLCodeGenerator::LValue> WGSLCodeGenerator::makeLValue(const Expression& e) {
2430 if (e.is<VariableReference>()) {
2431 return std::make_unique<PointerLValue>(
2432 this->variableReferenceNameForLValue(e.as<VariableReference>()));
2433 }
2434 if (e.is<FieldAccess>()) {
2435 return std::make_unique<PointerLValue>(this->assembleFieldAccess(e.as<FieldAccess>()));
2436 }
2437 if (e.is<IndexExpression>()) {
2438 const IndexExpression& idx = e.as<IndexExpression>();
2439 if (idx.base()->type().isVector()) {
2440 // Rewrite indexed-swizzle accesses like `myVec.zyx[i]` into an index onto `myVec`.
2441 if (std::unique_ptr<Expression> rewrite =
2442 Transform::RewriteIndexedSwizzle(fContext, idx)) {
2443 return std::make_unique<VectorComponentLValue>(
2444 this->assembleExpression(*rewrite, Precedence::kAssignment));
2445 } else {
2446 return std::make_unique<VectorComponentLValue>(this->assembleIndexExpression(idx));
2447 }
2448 } else {
2449 return std::make_unique<PointerLValue>(this->assembleIndexExpression(idx));
2450 }
2451 }
2452 if (e.is<Swizzle>()) {
2453 const Swizzle& swizzle = e.as<Swizzle>();
2454 if (swizzle.components().size() == 1) {
2455 return std::make_unique<VectorComponentLValue>(this->assembleSwizzle(swizzle));
2456 } else {
2457 return std::make_unique<SwizzleLValue>(
2458 fContext,
2459 this->assembleExpression(*swizzle.base(), Precedence::kAssignment),
2460 swizzle.base()->type(),
2461 swizzle.components());
2462 }
2463 }
2464
2465 fContext.fErrors->error(e.fPosition, "unsupported lvalue type");
2466 return nullptr;
2467 }
2468
assembleExpression(const Expression & e,Precedence parentPrecedence)2469 std::string WGSLCodeGenerator::assembleExpression(const Expression& e,
2470 Precedence parentPrecedence) {
2471 switch (e.kind()) {
2472 case Expression::Kind::kBinary:
2473 return this->assembleBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
2474
2475 case Expression::Kind::kConstructorCompound:
2476 return this->assembleConstructorCompound(e.as<ConstructorCompound>());
2477
2478 case Expression::Kind::kConstructorArrayCast:
2479 // This is a no-op, since WGSL 1.0 doesn't have any concept of precision qualifiers.
2480 // When we add support for f16, this will need to copy the array contents.
2481 return this->assembleExpression(*e.as<ConstructorArrayCast>().argument(),
2482 parentPrecedence);
2483
2484 case Expression::Kind::kConstructorArray:
2485 case Expression::Kind::kConstructorCompoundCast:
2486 case Expression::Kind::kConstructorScalarCast:
2487 case Expression::Kind::kConstructorSplat:
2488 case Expression::Kind::kConstructorStruct:
2489 return this->assembleAnyConstructor(e.asAnyConstructor());
2490
2491 case Expression::Kind::kConstructorDiagonalMatrix:
2492 return this->assembleConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
2493
2494 case Expression::Kind::kConstructorMatrixResize:
2495 return this->assembleConstructorMatrixResize(e.as<ConstructorMatrixResize>());
2496
2497 case Expression::Kind::kEmpty:
2498 return "false";
2499
2500 case Expression::Kind::kFieldAccess:
2501 return this->assembleFieldAccess(e.as<FieldAccess>());
2502
2503 case Expression::Kind::kFunctionCall:
2504 return this->assembleFunctionCall(e.as<FunctionCall>(), parentPrecedence);
2505
2506 case Expression::Kind::kIndex:
2507 return this->assembleIndexExpression(e.as<IndexExpression>());
2508
2509 case Expression::Kind::kLiteral:
2510 return this->assembleLiteral(e.as<Literal>());
2511
2512 case Expression::Kind::kPrefix:
2513 return this->assemblePrefixExpression(e.as<PrefixExpression>(), parentPrecedence);
2514
2515 case Expression::Kind::kPostfix:
2516 return this->assemblePostfixExpression(e.as<PostfixExpression>(), parentPrecedence);
2517
2518 case Expression::Kind::kSetting:
2519 return this->assembleExpression(*e.as<Setting>().toLiteral(fCaps), parentPrecedence);
2520
2521 case Expression::Kind::kSwizzle:
2522 return this->assembleSwizzle(e.as<Swizzle>());
2523
2524 case Expression::Kind::kTernary:
2525 return this->assembleTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
2526
2527 case Expression::Kind::kVariableReference:
2528 return this->assembleVariableReference(e.as<VariableReference>());
2529
2530 default:
2531 SkDEBUGFAILF("unsupported expression:\n%s", e.description().c_str());
2532 return {};
2533 }
2534 }
2535
is_nontrivial_expression(const Expression & expr)2536 static bool is_nontrivial_expression(const Expression& expr) {
2537 // We consider a "trivial expression" one which we can repeat multiple times in the output
2538 // without being dangerous or spammy. We avoid emitting temporary variables for very trivial
2539 // expressions: literals, unadorned variable references, or constant vectors.
2540 if (expr.is<VariableReference>() || expr.is<Literal>()) {
2541 // Variables and literals are trivial; adding a let-declaration won't simplify anything.
2542 return false;
2543 }
2544 if (expr.type().isVector() && Analysis::IsConstantExpression(expr)) {
2545 // Compile-time constant vectors are also considered trivial; they're short and sweet.
2546 return false;
2547 }
2548 return true;
2549 }
2550
binary_op_is_ambiguous_in_wgsl(Operator op)2551 static bool binary_op_is_ambiguous_in_wgsl(Operator op) {
2552 // WGSL always requires parentheses for some operators which are deemed to be ambiguous.
2553 // (8.19. Operator Precedence and Associativity)
2554 switch (op.kind()) {
2555 case OperatorKind::LOGICALOR:
2556 case OperatorKind::LOGICALAND:
2557 case OperatorKind::BITWISEOR:
2558 case OperatorKind::BITWISEAND:
2559 case OperatorKind::BITWISEXOR:
2560 case OperatorKind::SHL:
2561 case OperatorKind::SHR:
2562 case OperatorKind::LT:
2563 case OperatorKind::GT:
2564 case OperatorKind::LTEQ:
2565 case OperatorKind::GTEQ:
2566 return true;
2567
2568 default:
2569 return false;
2570 }
2571 }
2572
binaryOpNeedsComponentwiseMatrixPolyfill(const Type & left,const Type & right,Operator op)2573 bool WGSLCodeGenerator::binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left,
2574 const Type& right,
2575 Operator op) {
2576 switch (op.kind()) {
2577 case OperatorKind::SLASH:
2578 // WGSL does not natively support componentwise matrix-op-matrix for division.
2579 if (left.isMatrix() && right.isMatrix()) {
2580 return true;
2581 }
2582 [[fallthrough]];
2583
2584 case OperatorKind::PLUS:
2585 case OperatorKind::MINUS:
2586 // WGSL does not natively support componentwise matrix-op-scalar or scalar-op-matrix for
2587 // addition, subtraction or division.
2588 return (left.isMatrix() && right.isScalar()) ||
2589 (left.isScalar() && right.isMatrix());
2590
2591 default:
2592 return false;
2593 }
2594 }
2595
assembleBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)2596 std::string WGSLCodeGenerator::assembleBinaryExpression(const BinaryExpression& b,
2597 Precedence parentPrecedence) {
2598 return this->assembleBinaryExpression(*b.left(), b.getOperator(), *b.right(), b.type(),
2599 parentPrecedence);
2600 }
2601
assembleBinaryExpression(const Expression & left,Operator op,const Expression & right,const Type & resultType,Precedence parentPrecedence)2602 std::string WGSLCodeGenerator::assembleBinaryExpression(const Expression& left,
2603 Operator op,
2604 const Expression& right,
2605 const Type& resultType,
2606 Precedence parentPrecedence) {
2607 // If the operator is && or ||, we need to handle short-circuiting properly. Specifically, we
2608 // sometimes need to emit extra statements to paper over functionality that WGSL lacks, like
2609 // assignment in the middle of an expression. We need to guard those extra statements, to ensure
2610 // that they don't occur if the expression evaluation is short-circuited. Converting the
2611 // expression into an if-else block keeps the short-circuit property intact even when extra
2612 // statements are involved.
2613 // If the RHS doesn't have any side effects, then it's safe to just leave the expression as-is,
2614 // since we know that any possible extra statements are non-side-effecting.
2615 std::string expr;
2616 if (op.kind() == OperatorKind::LOGICALAND && Analysis::HasSideEffects(right)) {
2617 // Converts `left_expression && right_expression` into the following block:
2618
2619 // var _skTemp1: bool;
2620 // [[ prepare left_expression ]]
2621 // if left_expression {
2622 // [[ prepare right_expression ]]
2623 // _skTemp1 = right_expression;
2624 // } else {
2625 // _skTemp1 = false;
2626 // }
2627
2628 expr = this->writeScratchVar(resultType);
2629
2630 std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2631 this->write("if ");
2632 this->write(leftExpr);
2633 this->writeLine(" {");
2634
2635 ++fIndentation;
2636 std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2637 this->write(expr);
2638 this->write(" = ");
2639 this->write(rightExpr);
2640 this->writeLine(";");
2641 --fIndentation;
2642
2643 this->writeLine("} else {");
2644
2645 ++fIndentation;
2646 this->write(expr);
2647 this->writeLine(" = false;");
2648 --fIndentation;
2649
2650 this->writeLine("}");
2651 return expr;
2652 }
2653
2654 if (op.kind() == OperatorKind::LOGICALOR && Analysis::HasSideEffects(right)) {
2655 // Converts `left_expression || right_expression` into the following block:
2656
2657 // var _skTemp1: bool;
2658 // [[ prepare left_expression ]]
2659 // if left_expression {
2660 // _skTemp1 = true;
2661 // } else {
2662 // [[ prepare right_expression ]]
2663 // _skTemp1 = right_expression;
2664 // }
2665
2666 expr = this->writeScratchVar(resultType);
2667
2668 std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2669 this->write("if ");
2670 this->write(leftExpr);
2671 this->writeLine(" {");
2672
2673 ++fIndentation;
2674 this->write(expr);
2675 this->writeLine(" = true;");
2676 --fIndentation;
2677
2678 this->writeLine("} else {");
2679
2680 ++fIndentation;
2681 std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2682 this->write(expr);
2683 this->write(" = ");
2684 this->write(rightExpr);
2685 this->writeLine(";");
2686 --fIndentation;
2687
2688 this->writeLine("}");
2689 return expr;
2690 }
2691
2692 // Handle comma-expressions.
2693 if (op.kind() == OperatorKind::COMMA) {
2694 // The result from the left-expression is ignored, but its side effects must occur.
2695 this->assembleExpression(left, Precedence::kStatement);
2696
2697 // Evaluate the right side normally.
2698 return this->assembleExpression(right, parentPrecedence);
2699 }
2700
2701 // Handle assignment-expressions.
2702 if (op.isAssignment()) {
2703 std::unique_ptr<LValue> lvalue = this->makeLValue(left);
2704 if (!lvalue) {
2705 return "";
2706 }
2707
2708 if (op.kind() == OperatorKind::EQ) {
2709 // Evaluate the right-hand side of simple assignment (`a = b` --> `b`).
2710 expr = this->assembleExpression(right, Precedence::kAssignment);
2711 } else {
2712 // Evaluate the right-hand side of compound-assignment (`a += b` --> `a + b`).
2713 op = op.removeAssignment();
2714
2715 std::string lhs = lvalue->load();
2716 std::string rhs = this->assembleExpression(right, op.getBinaryPrecedence());
2717
2718 if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2719 if (is_nontrivial_expression(right)) {
2720 rhs = this->writeScratchLet(rhs);
2721 }
2722
2723 expr = this->assembleComponentwiseMatrixBinary(left.type(), right.type(),
2724 lhs, rhs, op);
2725 } else {
2726 expr = lhs + operator_name(op) + rhs;
2727 }
2728 }
2729
2730 // Emit the assignment statement (`a = a + b`).
2731 this->writeLine(lvalue->store(expr));
2732
2733 // Return the lvalue (`a`) as the result, since the value might be used by the caller.
2734 return lvalue->load();
2735 }
2736
2737 if (op.isEquality()) {
2738 return this->assembleEqualityExpression(left, right, op, parentPrecedence);
2739 }
2740
2741 Precedence precedence = op.getBinaryPrecedence();
2742 bool needParens = precedence >= parentPrecedence;
2743 if (binary_op_is_ambiguous_in_wgsl(op)) {
2744 precedence = Precedence::kParentheses;
2745 }
2746 if (needParens) {
2747 expr = "(";
2748 }
2749
2750 // If we are emitting `constant + constant`, this generally indicates that the values could not
2751 // be constant-folded. This happens when the values overflow or become nan. WGSL will refuse to
2752 // compile such expressions, as WGSL 1.0 has no infinity/nan support. However, the WGSL
2753 // compile-time check can be dodged by putting one side into a let-variable. This technically
2754 // gives us an indeterminate result, but the vast majority of backends will just calculate an
2755 // infinity or nan here, as we would expect. (skia:14385)
2756 bool bothSidesConstant = ConstantFolder::GetConstantValueOrNull(left) &&
2757 ConstantFolder::GetConstantValueOrNull(right);
2758
2759 std::string lhs = this->assembleExpression(left, precedence);
2760 std::string rhs = this->assembleExpression(right, precedence);
2761
2762 if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2763 if (bothSidesConstant || is_nontrivial_expression(left)) {
2764 lhs = this->writeScratchLet(lhs);
2765 }
2766 if (is_nontrivial_expression(right)) {
2767 rhs = this->writeScratchLet(rhs);
2768 }
2769
2770 expr += this->assembleComponentwiseMatrixBinary(left.type(), right.type(), lhs, rhs, op);
2771 } else {
2772 if (bothSidesConstant) {
2773 lhs = this->writeScratchLet(lhs);
2774 }
2775
2776 expr += lhs + operator_name(op) + rhs;
2777 }
2778
2779 if (needParens) {
2780 expr += ')';
2781 }
2782
2783 return expr;
2784 }
2785
assembleFieldAccess(const FieldAccess & f)2786 std::string WGSLCodeGenerator::assembleFieldAccess(const FieldAccess& f) {
2787 const Field* field = &f.base()->type().fields()[f.fieldIndex()];
2788 std::string expr;
2789
2790 if (FieldPolyfillInfo* polyfillInfo = fFieldPolyfillMap.find(field)) {
2791 // We found a matrix uniform. We are required to pass some matrix uniforms as array vectors,
2792 // since the std140 layout for a matrix assumes 4-column vectors for each row, and WGSL
2793 // tightly packs 2-column matrices. When emitting code, we replace the field-access
2794 // expression with a global variable which holds an unpacked version of the uniform.
2795 polyfillInfo->fWasAccessed = true;
2796
2797 // The polyfill can either be based directly onto a uniform in an interface block, or it
2798 // might be based on an index-expression onto a uniform if the interface block is arrayed.
2799 const Expression* base = f.base().get();
2800 const IndexExpression* indexExpr = nullptr;
2801 if (base->is<IndexExpression>()) {
2802 indexExpr = &base->as<IndexExpression>();
2803 base = indexExpr->base().get();
2804 }
2805
2806 SkASSERT(base->is<VariableReference>());
2807 expr = polyfillInfo->fReplacementName;
2808
2809 // If we had an index expression, we must append the index.
2810 if (indexExpr) {
2811 expr += '[';
2812 expr += this->assembleExpression(*indexExpr->index(), Precedence::kSequence);
2813 expr += ']';
2814 }
2815 return expr;
2816 }
2817
2818 switch (f.ownerKind()) {
2819 case FieldAccess::OwnerKind::kDefault:
2820 expr = this->assembleExpression(*f.base(), Precedence::kPostfix) + '.';
2821 break;
2822
2823 case FieldAccess::OwnerKind::kAnonymousInterfaceBlock:
2824 if (f.base()->is<VariableReference>() &&
2825 field->fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
2826 expr = this->variablePrefix(*f.base()->as<VariableReference>().variable());
2827 }
2828 break;
2829 }
2830
2831 expr += this->assembleName(field->fName);
2832 return expr;
2833 }
2834
all_arguments_constant(const ExpressionArray & arguments)2835 static bool all_arguments_constant(const ExpressionArray& arguments) {
2836 // Returns true if all arguments in the ExpressionArray are compile-time constants. If we are
2837 // calling an intrinsic and all of its inputs are constant, but we didn't constant-fold it, this
2838 // generally indicates that constant-folding resulted in an infinity or nan. The WGSL compiler
2839 // will reject such an expression with a compile-time error. We can dodge the error, taking on
2840 // the risk of indeterminate behavior instead, by replacing one of the constant values with a
2841 // scratch let-variable. (skia:14385)
2842 for (const std::unique_ptr<Expression>& arg : arguments) {
2843 if (!ConstantFolder::GetConstantValueOrNull(*arg)) {
2844 return false;
2845 }
2846 }
2847 return true;
2848 }
2849
assembleSimpleIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2850 std::string WGSLCodeGenerator::assembleSimpleIntrinsic(std::string_view intrinsicName,
2851 const FunctionCall& call) {
2852 // Invoke the function, passing each function argument.
2853 std::string expr = std::string(intrinsicName);
2854 expr.push_back('(');
2855 const ExpressionArray& args = call.arguments();
2856 auto separator = SkSL::String::Separator();
2857 bool allConstant = all_arguments_constant(call.arguments());
2858 for (int index = 0; index < args.size(); ++index) {
2859 expr += separator();
2860
2861 std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2862 if (args[index]->type().isAtomic()) {
2863 // WGSL passes atomic values to intrinsics as pointers.
2864 expr += '&';
2865 expr += argument;
2866 } else if (allConstant && index == 0) {
2867 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2868 expr += this->writeScratchLet(argument);
2869 } else {
2870 expr += argument;
2871 }
2872 }
2873 expr.push_back(')');
2874
2875 if (call.type().isVoid()) {
2876 this->write(expr);
2877 this->writeLine(";");
2878 return std::string();
2879 } else {
2880 return this->writeScratchLet(expr);
2881 }
2882 }
2883
assembleVectorizedIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2884 std::string WGSLCodeGenerator::assembleVectorizedIntrinsic(std::string_view intrinsicName,
2885 const FunctionCall& call) {
2886 SkASSERT(!call.type().isVoid());
2887
2888 // Invoke the function, passing each function argument.
2889 std::string expr = std::string(intrinsicName);
2890 expr.push_back('(');
2891
2892 auto separator = SkSL::String::Separator();
2893 const ExpressionArray& args = call.arguments();
2894 bool returnsVector = call.type().isVector();
2895 bool allConstant = all_arguments_constant(call.arguments());
2896 for (int index = 0; index < args.size(); ++index) {
2897 expr += separator();
2898
2899 bool vectorize = returnsVector && args[index]->type().isScalar();
2900 if (vectorize) {
2901 expr += to_wgsl_type(fContext, call.type());
2902 expr.push_back('(');
2903 }
2904
2905 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2906 std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2907 expr += (allConstant && index == 0) ? this->writeScratchLet(argument)
2908 : argument;
2909 if (vectorize) {
2910 expr.push_back(')');
2911 }
2912 }
2913 expr.push_back(')');
2914
2915 return this->writeScratchLet(expr);
2916 }
2917
assembleUnaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2918 std::string WGSLCodeGenerator::assembleUnaryOpIntrinsic(Operator op,
2919 const FunctionCall& call,
2920 Precedence parentPrecedence) {
2921 SkASSERT(!call.type().isVoid());
2922
2923 bool needParens = Precedence::kPrefix >= parentPrecedence;
2924
2925 std::string expr;
2926 if (needParens) {
2927 expr.push_back('(');
2928 }
2929
2930 expr += operator_name(op);
2931 expr += this->assembleExpression(*call.arguments()[0], Precedence::kPrefix);
2932
2933 if (needParens) {
2934 expr.push_back(')');
2935 }
2936
2937 return expr;
2938 }
2939
assembleBinaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2940 std::string WGSLCodeGenerator::assembleBinaryOpIntrinsic(Operator op,
2941 const FunctionCall& call,
2942 Precedence parentPrecedence) {
2943 SkASSERT(!call.type().isVoid());
2944
2945 Precedence precedence = op.getBinaryPrecedence();
2946 bool needParens = precedence >= parentPrecedence ||
2947 binary_op_is_ambiguous_in_wgsl(op);
2948 std::string expr;
2949 if (needParens) {
2950 expr.push_back('(');
2951 }
2952
2953 // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2954 std::string argument = this->assembleExpression(*call.arguments()[0], precedence);
2955 expr += all_arguments_constant(call.arguments()) ? this->writeScratchLet(argument)
2956 : argument;
2957 expr += operator_name(op);
2958 expr += this->assembleExpression(*call.arguments()[1], precedence);
2959
2960 if (needParens) {
2961 expr.push_back(')');
2962 }
2963
2964 return expr;
2965 }
2966
2967 // Rewrite a WGSL intrinsic of the form "intrinsicName(in) -> struct" to the SkSL's
2968 // "intrinsicName(in, outField) -> returnField", where outField and returnField are the names of the
2969 // fields in the struct returned by the WGSL intrinsic.
assembleOutAssignedIntrinsic(std::string_view intrinsicName,std::string_view returnField,std::string_view outField,const FunctionCall & call)2970 std::string WGSLCodeGenerator::assembleOutAssignedIntrinsic(std::string_view intrinsicName,
2971 std::string_view returnField,
2972 std::string_view outField,
2973 const FunctionCall& call) {
2974 SkASSERT(call.type().componentType().isNumber());
2975 SkASSERT(call.arguments().size() == 2);
2976 SkASSERT(call.function().parameters()[1]->modifierFlags() & ModifierFlag::kOut);
2977
2978 std::string expr = std::string(intrinsicName);
2979 expr += "(";
2980
2981 // Invoke the intrinsic with the first parameter. Use a scratch-let if argument is a constant
2982 // to dodge WGSL overflow errors. (skia:14385)
2983 std::string argument = this->assembleExpression(*call.arguments()[0], Precedence::kSequence);
2984 expr += ConstantFolder::GetConstantValueOrNull(*call.arguments()[0])
2985 ? this->writeScratchLet(argument) : argument;
2986 expr += ")";
2987 // In WGSL the intrinsic returns a struct; assign it to a local so that its fields can be
2988 // accessed multiple times.
2989 expr = this->writeScratchLet(expr);
2990 expr += ".";
2991
2992 // Store the outField of `expr` to the intended "out" argument
2993 std::unique_ptr<LValue> lvalue = this->makeLValue(*call.arguments()[1]);
2994 if (!lvalue) {
2995 return "";
2996 }
2997 std::string outValue = expr;
2998 outValue += outField;
2999 this->writeLine(lvalue->store(outValue));
3000
3001 // And return the expression accessing the returnField.
3002 expr += returnField;
3003 return expr;
3004 }
3005
assemblePartialSampleCall(std::string_view functionName,const Expression & sampler,const Expression & coords)3006 std::string WGSLCodeGenerator::assemblePartialSampleCall(std::string_view functionName,
3007 const Expression& sampler,
3008 const Expression& coords) {
3009 // This function returns `functionName(inSampler_texture, inSampler_sampler, coords` without a
3010 // terminating comma or close-parenthesis. This allows the caller to add more arguments as
3011 // needed.
3012 SkASSERT(sampler.type().typeKind() == Type::TypeKind::kSampler);
3013 std::string expr = std::string(functionName) + '(';
3014 expr += this->assembleExpression(sampler, Precedence::kSequence);
3015 expr += kTextureSuffix;
3016 expr += ", ";
3017 expr += this->assembleExpression(sampler, Precedence::kSequence);
3018 expr += kSamplerSuffix;
3019 expr += ", ";
3020
3021 // Compute the sample coordinates, dividing out the Z if a vec3 was provided.
3022 SkASSERT(coords.type().isVector());
3023 if (coords.type().columns() == 3) {
3024 // The coordinates were passed as a vec3, so we need to emit `coords.xy / coords.z`.
3025 std::string vec3Coords = this->writeScratchLet(coords, Precedence::kMultiplicative);
3026 expr += vec3Coords + ".xy / " + vec3Coords + ".z";
3027 } else {
3028 // The coordinates should be a plain vec2; emit the expression as-is.
3029 SkASSERT(coords.type().columns() == 2);
3030 expr += this->assembleExpression(coords, Precedence::kSequence);
3031 }
3032
3033 return expr;
3034 }
3035
assembleComponentwiseMatrixBinary(const Type & leftType,const Type & rightType,const std::string & left,const std::string & right,Operator op)3036 std::string WGSLCodeGenerator::assembleComponentwiseMatrixBinary(const Type& leftType,
3037 const Type& rightType,
3038 const std::string& left,
3039 const std::string& right,
3040 Operator op) {
3041 bool leftIsMatrix = leftType.isMatrix();
3042 bool rightIsMatrix = rightType.isMatrix();
3043 const Type& matrixType = leftIsMatrix ? leftType : rightType;
3044
3045 std::string expr = to_wgsl_type(fContext, matrixType) + '(';
3046 auto separator = String::Separator();
3047 int columns = matrixType.columns();
3048 for (int c = 0; c < columns; ++c) {
3049 expr += separator();
3050 expr += left;
3051 if (leftIsMatrix) {
3052 expr += '[';
3053 expr += std::to_string(c);
3054 expr += ']';
3055 }
3056 expr += op.operatorName();
3057 expr += right;
3058 if (rightIsMatrix) {
3059 expr += '[';
3060 expr += std::to_string(c);
3061 expr += ']';
3062 }
3063 }
3064 return expr + ')';
3065 }
3066
assembleIntrinsicCall(const FunctionCall & call,IntrinsicKind kind,Precedence parentPrecedence)3067 std::string WGSLCodeGenerator::assembleIntrinsicCall(const FunctionCall& call,
3068 IntrinsicKind kind,
3069 Precedence parentPrecedence) {
3070 // Be careful: WGSL 1.0 will reject any intrinsic calls which can be constant-evaluated to
3071 // infinity or nan with a compile error. If all arguments to an intrinsic are compile-time
3072 // constants (`all_arguments_constant`), it is safest to copy one argument into a scratch-let so
3073 // that the call will be seen as runtime-evaluated, which defuses the overflow checks.
3074 // Don't worry; a competent driver should still optimize it away.
3075
3076 const ExpressionArray& arguments = call.arguments();
3077 switch (kind) {
3078 case k_atan_IntrinsicKind: {
3079 const char* name = (arguments.size() == 1) ? "atan" : "atan2";
3080 return this->assembleSimpleIntrinsic(name, call);
3081 }
3082 case k_dFdx_IntrinsicKind:
3083 return this->assembleSimpleIntrinsic("dpdx", call);
3084
3085 case k_dFdy_IntrinsicKind:
3086 // TODO(b/294274678): apply RTFlip here
3087 return this->assembleSimpleIntrinsic("dpdy", call);
3088
3089 case k_dot_IntrinsicKind: {
3090 if (arguments[0]->type().isScalar()) {
3091 return this->assembleBinaryOpIntrinsic(OperatorKind::STAR, call, parentPrecedence);
3092 }
3093 return this->assembleSimpleIntrinsic("dot", call);
3094 }
3095 case k_equal_IntrinsicKind:
3096 return this->assembleBinaryOpIntrinsic(OperatorKind::EQEQ, call, parentPrecedence);
3097
3098 case k_faceforward_IntrinsicKind: {
3099 if (arguments[0]->type().isScalar()) {
3100 // select(-N, N, (I * Nref) < 0)
3101 std::string N = this->writeNontrivialScratchLet(*arguments[0],
3102 Precedence::kAssignment);
3103 return this->writeScratchLet(
3104 "select(-" + N + ", " + N + ", " +
3105 this->assembleBinaryExpression(*arguments[1],
3106 OperatorKind::STAR,
3107 *arguments[2],
3108 arguments[1]->type(),
3109 Precedence::kRelational) +
3110 " < 0)");
3111 }
3112 return this->assembleSimpleIntrinsic("faceForward", call);
3113 }
3114 case k_frexp_IntrinsicKind:
3115 // SkSL frexp is "$genType fract = frexp($genType, out $genIType exp)" whereas WGSL
3116 // returns a struct with no out param: "let [fract, exp] = frexp($genType)".
3117 return this->assembleOutAssignedIntrinsic("frexp", "fract", "exp", call);
3118
3119 case k_greaterThan_IntrinsicKind:
3120 return this->assembleBinaryOpIntrinsic(OperatorKind::GT, call, parentPrecedence);
3121
3122 case k_greaterThanEqual_IntrinsicKind:
3123 return this->assembleBinaryOpIntrinsic(OperatorKind::GTEQ, call, parentPrecedence);
3124
3125 case k_inverse_IntrinsicKind:
3126 return this->assembleInversePolyfill(call);
3127
3128 case k_inversesqrt_IntrinsicKind:
3129 return this->assembleSimpleIntrinsic("inverseSqrt", call);
3130
3131 case k_lessThan_IntrinsicKind:
3132 return this->assembleBinaryOpIntrinsic(OperatorKind::LT, call, parentPrecedence);
3133
3134 case k_lessThanEqual_IntrinsicKind:
3135 return this->assembleBinaryOpIntrinsic(OperatorKind::LTEQ, call, parentPrecedence);
3136
3137 case k_matrixCompMult_IntrinsicKind: {
3138 // We use a scratch-let for arg0 to avoid the potential for WGSL overflow. (skia:14385)
3139 std::string arg0 = all_arguments_constant(arguments)
3140 ? this->writeScratchLet(*arguments[0], Precedence::kPostfix)
3141 : this->writeNontrivialScratchLet(*arguments[0], Precedence::kPostfix);
3142 std::string arg1 = this->writeNontrivialScratchLet(*arguments[1], Precedence::kPostfix);
3143 return this->writeScratchLet(
3144 this->assembleComponentwiseMatrixBinary(arguments[0]->type(),
3145 arguments[1]->type(),
3146 arg0,
3147 arg1,
3148 OperatorKind::STAR));
3149 }
3150 case k_mix_IntrinsicKind: {
3151 const char* name = arguments[2]->type().componentType().isBoolean() ? "select" : "mix";
3152 return this->assembleVectorizedIntrinsic(name, call);
3153 }
3154 case k_mod_IntrinsicKind: {
3155 // WGSL has no intrinsic equivalent to `mod`. Synthesize `x - y * floor(x / y)`.
3156 // We can use a scratch-let on one side to dodge WGSL overflow errors. In practice, I
3157 // can't find any values of x or y which would overflow, but it can't hurt. (skia:14385)
3158 std::string arg0 = all_arguments_constant(arguments)
3159 ? this->writeScratchLet(*arguments[0], Precedence::kAdditive)
3160 : this->writeNontrivialScratchLet(*arguments[0], Precedence::kAdditive);
3161 std::string arg1 = this->writeNontrivialScratchLet(*arguments[1],
3162 Precedence::kAdditive);
3163 return this->writeScratchLet(arg0 + " - " + arg1 + " * floor(" +
3164 arg0 + " / " + arg1 + ")");
3165 }
3166
3167 case k_modf_IntrinsicKind:
3168 // SkSL modf is "$genType fract = modf($genType, out $genType whole)" whereas WGSL
3169 // returns a struct with no out param: "let [fract, whole] = modf($genType)".
3170 return this->assembleOutAssignedIntrinsic("modf", "fract", "whole", call);
3171
3172 case k_normalize_IntrinsicKind: {
3173 const char* name = arguments[0]->type().isScalar() ? "sign" : "normalize";
3174 return this->assembleSimpleIntrinsic(name, call);
3175 }
3176 case k_not_IntrinsicKind:
3177 return this->assembleUnaryOpIntrinsic(OperatorKind::LOGICALNOT, call, parentPrecedence);
3178
3179 case k_notEqual_IntrinsicKind:
3180 return this->assembleBinaryOpIntrinsic(OperatorKind::NEQ, call, parentPrecedence);
3181
3182 case k_packHalf2x16_IntrinsicKind:
3183 return this->assembleSimpleIntrinsic("pack2x16float", call);
3184
3185 case k_packSnorm2x16_IntrinsicKind:
3186 return this->assembleSimpleIntrinsic("pack2x16snorm", call);
3187
3188 case k_packSnorm4x8_IntrinsicKind:
3189 return this->assembleSimpleIntrinsic("pack4x8snorm", call);
3190
3191 case k_packUnorm2x16_IntrinsicKind:
3192 return this->assembleSimpleIntrinsic("pack2x16unorm", call);
3193
3194 case k_packUnorm4x8_IntrinsicKind:
3195 return this->assembleSimpleIntrinsic("pack4x8unorm", call);
3196
3197 case k_reflect_IntrinsicKind:
3198 if (arguments[0]->type().isScalar()) {
3199 // I - 2 * N * I * N
3200 // We can use a scratch-let for N to dodge WGSL overflow errors. (skia:14385)
3201 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3202 Precedence::kAdditive);
3203 std::string N = all_arguments_constant(arguments)
3204 ? this->writeScratchLet(*arguments[1], Precedence::kMultiplicative)
3205 : this->writeNontrivialScratchLet(*arguments[1], Precedence::kMultiplicative);
3206 return this->writeScratchLet(String::printf("%s - 2 * %s * %s * %s",
3207 I.c_str(), N.c_str(),
3208 I.c_str(), N.c_str()));
3209 }
3210 return this->assembleSimpleIntrinsic("reflect", call);
3211
3212 case k_refract_IntrinsicKind:
3213 if (arguments[0]->type().isScalar()) {
3214 // WGSL only implements refract for vectors; rather than reimplementing refract from
3215 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
3216 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3217 Precedence::kSequence);
3218 std::string N = this->writeNontrivialScratchLet(*arguments[1],
3219 Precedence::kSequence);
3220 // We can use a scratch-let for Eta to avoid WGSL overflow errors. (skia:14385)
3221 std::string Eta = all_arguments_constant(arguments)
3222 ? this->writeScratchLet(*arguments[2], Precedence::kSequence)
3223 : this->writeNontrivialScratchLet(*arguments[2], Precedence::kSequence);
3224 return this->writeScratchLet(
3225 String::printf("refract(vec2<%s>(%s, 0), vec2<%s>(%s, 0), %s).x",
3226 to_wgsl_type(fContext, arguments[0]->type()).c_str(),
3227 I.c_str(),
3228 to_wgsl_type(fContext, arguments[1]->type()).c_str(),
3229 N.c_str(),
3230 Eta.c_str()));
3231 }
3232 return this->assembleSimpleIntrinsic("refract", call);
3233
3234 case k_sample_IntrinsicKind: {
3235 // Determine if a bias argument was passed in.
3236 SkASSERT(arguments.size() == 2 || arguments.size() == 3);
3237 bool callIncludesBias = (arguments.size() == 3);
3238
3239 if (fProgram.fConfig->fSettings.fSharpenTextures || callIncludesBias) {
3240 // We need to supply a bias argument; this is a separate intrinsic in WGSL.
3241 std::string expr = this->assemblePartialSampleCall("textureSampleBias",
3242 *arguments[0],
3243 *arguments[1]);
3244 expr += ", ";
3245 if (callIncludesBias) {
3246 expr += this->assembleExpression(*arguments[2], Precedence::kAdditive) +
3247 " + ";
3248 }
3249 expr += skstd::to_string(fProgram.fConfig->fSettings.fSharpenTextures
3250 ? kSharpenTexturesBias
3251 : 0.0f);
3252 return expr + ')';
3253 }
3254
3255 // No bias is necessary, so we can call `textureSample` directly.
3256 return this->assemblePartialSampleCall("textureSample",
3257 *arguments[0],
3258 *arguments[1]) + ')';
3259 }
3260 case k_sampleLod_IntrinsicKind: {
3261 std::string expr = this->assemblePartialSampleCall("textureSampleLevel",
3262 *arguments[0],
3263 *arguments[1]);
3264 expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3265 return expr + ')';
3266 }
3267 case k_sampleGrad_IntrinsicKind: {
3268 std::string expr = this->assemblePartialSampleCall("textureSampleGrad",
3269 *arguments[0],
3270 *arguments[1]);
3271 expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3272 expr += ", " + this->assembleExpression(*arguments[3], Precedence::kSequence);
3273 return expr + ')';
3274 }
3275 case k_textureHeight_IntrinsicKind:
3276 return this->assembleSimpleIntrinsic("textureDimensions", call) + ".y";
3277
3278 case k_textureRead_IntrinsicKind: {
3279 // We need to inject an extra argument for the mip-level. We don't plan on using mipmaps
3280 // in our storage textures, so we can just pass zero.
3281 std::string tex = this->assembleExpression(*arguments[0], Precedence::kSequence);
3282 std::string pos = this->writeScratchLet(*arguments[1], Precedence::kSequence);
3283 return std::string("textureLoad(") + tex + ", " + pos + ", 0)";
3284 }
3285 case k_textureWidth_IntrinsicKind:
3286 return this->assembleSimpleIntrinsic("textureDimensions", call) + ".x";
3287
3288 case k_textureWrite_IntrinsicKind:
3289 return this->assembleSimpleIntrinsic("textureStore", call);
3290
3291 case k_unpackHalf2x16_IntrinsicKind:
3292 return this->assembleSimpleIntrinsic("unpack2x16float", call);
3293
3294 case k_unpackSnorm2x16_IntrinsicKind:
3295 return this->assembleSimpleIntrinsic("unpack2x16snorm", call);
3296
3297 case k_unpackSnorm4x8_IntrinsicKind:
3298 return this->assembleSimpleIntrinsic("unpack4x8snorm", call);
3299
3300 case k_unpackUnorm2x16_IntrinsicKind:
3301 return this->assembleSimpleIntrinsic("unpack2x16unorm", call);
3302
3303 case k_unpackUnorm4x8_IntrinsicKind:
3304 return this->assembleSimpleIntrinsic("unpack4x8unorm", call);
3305
3306 case k_clamp_IntrinsicKind:
3307 case k_max_IntrinsicKind:
3308 case k_min_IntrinsicKind:
3309 case k_smoothstep_IntrinsicKind:
3310 case k_step_IntrinsicKind:
3311 return this->assembleVectorizedIntrinsic(call.function().name(), call);
3312
3313 case k_abs_IntrinsicKind:
3314 case k_acos_IntrinsicKind:
3315 case k_all_IntrinsicKind:
3316 case k_any_IntrinsicKind:
3317 case k_asin_IntrinsicKind:
3318 case k_atomicAdd_IntrinsicKind:
3319 case k_atomicLoad_IntrinsicKind:
3320 case k_atomicStore_IntrinsicKind:
3321 case k_ceil_IntrinsicKind:
3322 case k_cos_IntrinsicKind:
3323 case k_cross_IntrinsicKind:
3324 case k_degrees_IntrinsicKind:
3325 case k_distance_IntrinsicKind:
3326 case k_exp_IntrinsicKind:
3327 case k_exp2_IntrinsicKind:
3328 case k_floor_IntrinsicKind:
3329 case k_fract_IntrinsicKind:
3330 case k_length_IntrinsicKind:
3331 case k_log_IntrinsicKind:
3332 case k_log2_IntrinsicKind:
3333 case k_radians_IntrinsicKind:
3334 case k_pow_IntrinsicKind:
3335 case k_saturate_IntrinsicKind:
3336 case k_sign_IntrinsicKind:
3337 case k_sin_IntrinsicKind:
3338 case k_sqrt_IntrinsicKind:
3339 case k_storageBarrier_IntrinsicKind:
3340 case k_tan_IntrinsicKind:
3341 case k_workgroupBarrier_IntrinsicKind:
3342 default:
3343 return this->assembleSimpleIntrinsic(call.function().name(), call);
3344 }
3345 }
3346
// WGSL source for a `mat2_inverse` helper. WGSL has no matrix-inverse builtin, so
// assembleInversePolyfill writes this text into the program header the first time a 2x2
// `inverse` is used (guarded by fWrittenInverse2).
// The helper swaps the diagonal, negates the off-diagonal terms (the adjugate), and scales the
// result by `1/determinant(m)`. A singular matrix is not special-cased.
static constexpr char kInverse2x2[] =
"fn mat2_inverse(m: mat2x2<f32>) -> mat2x2<f32> {"
"\n" "return mat2x2<f32>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));"
"\n" "}"
"\n";
3352
// WGSL source for a `mat3_inverse` helper, written into the program header once by
// assembleInversePolyfill (guarded by fWrittenInverse3).
// The helper works as follows:
// - b01, b11 and b21 are the cofactors of the first row.
// - `det` is the expansion of the determinant along that row.
// - The result is the adjugate (transposed cofactor matrix) scaled by `1/det`.
// A zero determinant is not special-cased.
static constexpr char kInverse3x3[] =
"fn mat3_inverse(m: mat3x3<f32>) -> mat3x3<f32> {"
"\n" "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z;"
"\n" "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z;"
"\n" "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z;"
"\n" "let b01 = a22*a11 - a12*a21;"
"\n" "let b11 = -a22*a10 + a12*a20;"
"\n" "let b21 = a21*a10 - a11*a20;"
"\n" "let det = a00*b01 + a01*b11 + a02*b21;"
"\n" "return mat3x3<f32>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),"
"\n" "b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),"
"\n" "b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);"
"\n" "}"
"\n";
3367
// WGSL source for a `mat4_inverse` helper, written into the program header once by
// assembleInversePolyfill (guarded by fWrittenInverse4).
// The helper works as follows:
// - b00..b05 are the 2x2 minors formed from columns 0 and 1 of `m`.
// - b06..b11 are the 2x2 minors formed from columns 2 and 3.
// - `det` combines those minors into the determinant.
// - Each of the sixteen result entries is a cofactor built from the same minors, and the
//   resulting matrix is scaled by `1/det`.
// A zero determinant is not special-cased.
static constexpr char kInverse4x4[] =
"fn mat4_inverse(m: mat4x4<f32>) -> mat4x4<f32>{"
"\n" "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z; let a03 = m[0].w;"
"\n" "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z; let a13 = m[1].w;"
"\n" "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z; let a23 = m[2].w;"
"\n" "let a30 = m[3].x; let a31 = m[3].y; let a32 = m[3].z; let a33 = m[3].w;"
"\n" "let b00 = a00*a11 - a01*a10;"
"\n" "let b01 = a00*a12 - a02*a10;"
"\n" "let b02 = a00*a13 - a03*a10;"
"\n" "let b03 = a01*a12 - a02*a11;"
"\n" "let b04 = a01*a13 - a03*a11;"
"\n" "let b05 = a02*a13 - a03*a12;"
"\n" "let b06 = a20*a31 - a21*a30;"
"\n" "let b07 = a20*a32 - a22*a30;"
"\n" "let b08 = a20*a33 - a23*a30;"
"\n" "let b09 = a21*a32 - a22*a31;"
"\n" "let b10 = a21*a33 - a23*a31;"
"\n" "let b11 = a22*a33 - a23*a32;"
"\n" "let det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;"
"\n" "return mat4x4<f32>(a11*b11 - a12*b10 + a13*b09,"
"\n" "a02*b10 - a01*b11 - a03*b09,"
"\n" "a31*b05 - a32*b04 + a33*b03,"
"\n" "a22*b04 - a21*b05 - a23*b03,"
"\n" "a12*b08 - a10*b11 - a13*b07,"
"\n" "a00*b11 - a02*b08 + a03*b07,"
"\n" "a32*b02 - a30*b05 - a33*b01,"
"\n" "a20*b05 - a22*b02 + a23*b01,"
"\n" "a10*b10 - a11*b08 + a13*b06,"
"\n" "a01*b08 - a00*b10 - a03*b06,"
"\n" "a30*b04 - a31*b02 + a33*b00,"
"\n" "a21*b02 - a20*b04 - a23*b00,"
"\n" "a11*b07 - a10*b09 - a12*b06,"
"\n" "a00*b09 - a01*b07 + a02*b06,"
"\n" "a31*b01 - a30*b03 - a32*b00,"
"\n" "a20*b03 - a21*b01 + a22*b00) * (1/det);"
"\n" "}"
"\n";
3405
assembleInversePolyfill(const FunctionCall & call)3406 std::string WGSLCodeGenerator::assembleInversePolyfill(const FunctionCall& call) {
3407 const ExpressionArray& arguments = call.arguments();
3408 const Type& type = arguments.front()->type();
3409
3410 // The `inverse` intrinsic should only accept a single-argument square matrix.
3411 // Once we implement f16 support, these polyfills will need to be updated to support `hmat`;
3412 // for the time being, all floats in WGSL are f32, so we don't need to worry about precision.
3413 SkASSERT(arguments.size() == 1);
3414 SkASSERT(type.isMatrix());
3415 SkASSERT(type.rows() == type.columns());
3416
3417 switch (type.slotCount()) {
3418 case 4:
3419 if (!fWrittenInverse2) {
3420 fWrittenInverse2 = true;
3421 fHeader.writeText(kInverse2x2);
3422 }
3423 return this->assembleSimpleIntrinsic("mat2_inverse", call);
3424
3425 case 9:
3426 if (!fWrittenInverse3) {
3427 fWrittenInverse3 = true;
3428 fHeader.writeText(kInverse3x3);
3429 }
3430 return this->assembleSimpleIntrinsic("mat3_inverse", call);
3431
3432 case 16:
3433 if (!fWrittenInverse4) {
3434 fWrittenInverse4 = true;
3435 fHeader.writeText(kInverse4x4);
3436 }
3437 return this->assembleSimpleIntrinsic("mat4_inverse", call);
3438
3439 default:
3440 // We only support square matrices.
3441 SkUNREACHABLE;
3442 }
3443 }
3444
assembleFunctionCall(const FunctionCall & call,Precedence parentPrecedence)3445 std::string WGSLCodeGenerator::assembleFunctionCall(const FunctionCall& call,
3446 Precedence parentPrecedence) {
3447 const FunctionDeclaration& func = call.function();
3448 std::string result;
3449
3450 // Many intrinsics need to be rewritten in WGSL.
3451 if (func.isIntrinsic()) {
3452 return this->assembleIntrinsicCall(call, func.intrinsicKind(), parentPrecedence);
3453 }
3454
3455 // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
3456 // out-parameter semantics, in which out-parameters are only written back to the original
3457 // variable after the function's execution is complete (see
3458 // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
3459 //
3460 // In addition, SkSL supports swizzles and array index expressions to be passed into
3461 // out-parameters; however, WGSL does not allow taking their address into a pointer.
3462 //
3463 // We support these by using LValues to create temporary copies and then pass pointers to the
3464 // copies. Once the function returns, we copy the values back to the LValue.
3465
3466 // First detect which arguments are passed to out-parameters.
3467 // TODO: rewrite this method in terms of LValues.
3468 const ExpressionArray& args = call.arguments();
3469 SkSpan<Variable* const> params = func.parameters();
3470 SkASSERT(SkToSizeT(args.size()) == params.size());
3471
3472 STArray<16, std::unique_ptr<LValue>> writeback;
3473 STArray<16, std::string> substituteArgument;
3474 writeback.reserve_exact(args.size());
3475 substituteArgument.reserve_exact(args.size());
3476
3477 for (int index = 0; index < args.size(); ++index) {
3478 if (params[index]->modifierFlags() & ModifierFlag::kOut) {
3479 std::unique_ptr<LValue> lvalue = this->makeLValue(*args[index]);
3480 if (params[index]->modifierFlags() & ModifierFlag::kIn) {
3481 // Load the lvalue's contents into the substitute argument.
3482 substituteArgument.push_back(this->writeScratchVar(args[index]->type(),
3483 lvalue->load()));
3484 } else {
3485 // Create a substitute argument, but leave it uninitialized.
3486 substituteArgument.push_back(this->writeScratchVar(args[index]->type()));
3487 }
3488 writeback.push_back(std::move(lvalue));
3489 } else {
3490 substituteArgument.push_back(std::string());
3491 writeback.push_back(nullptr);
3492 }
3493 }
3494
3495 std::string expr = this->assembleName(func.mangledName());
3496 expr.push_back('(');
3497 auto separator = SkSL::String::Separator();
3498
3499 if (std::string funcDepArgs = this->functionDependencyArgs(func); !funcDepArgs.empty()) {
3500 expr += funcDepArgs;
3501 separator();
3502 }
3503
3504 // Pass the function arguments, or any substitutes as needed.
3505 for (int index = 0; index < args.size(); ++index) {
3506 expr += separator();
3507 if (!substituteArgument[index].empty()) {
3508 // We need to take the address of the variable and pass it down as a pointer.
3509 expr += '&' + substituteArgument[index];
3510 } else if (args[index]->type().isSampler()) {
3511 // If the argument is a sampler, we need to pass the texture _and_ its associated
3512 // sampler. (Function parameter lists also convert sampler parameters into a matching
3513 // texture/sampler parameter pair.)
3514 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3515 expr += kTextureSuffix;
3516 expr += ", ";
3517 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3518 expr += kSamplerSuffix;
3519 } else if (args[index]->type().isUnsizedArray()) {
3520 // If the array is in the parameter storage space then manually just pass it through
3521 // since it is already a pointer.
3522 if (args[index]->is<VariableReference>()) {
3523 const Variable* v = args[index]->as<VariableReference>().variable();
3524 // A variable reference to an unsized array should always be a parameter,
3525 // because unsized arrays coming from uniforms will have the `FieldAccess`
3526 // expression type.
3527 SkASSERT(v->storage() == Variable::Storage::kParameter);
3528 expr += this->assembleName(v->mangledName());
3529 } else {
3530 expr += "&(" + this->assembleExpression(*args[index], Precedence::kSequence) + ")";
3531 }
3532 } else {
3533 expr += this->assembleExpression(*args[index], Precedence::kSequence);
3534 }
3535 }
3536 expr += ')';
3537
3538 if (call.type().isVoid()) {
3539 // Making function calls that result in `void` is only valid in on the left side of a
3540 // comma-sequence, or in a top-level statement. Emit the function call as a top-level
3541 // statement and return an empty string, as the result will not be used.
3542 SkASSERT(parentPrecedence >= Precedence::kSequence);
3543 this->write(expr);
3544 this->writeLine(";");
3545 } else {
3546 result = this->writeScratchLet(expr);
3547 }
3548
3549 // Write the substitute arguments back into their lvalues.
3550 for (int index = 0; index < args.size(); ++index) {
3551 if (!substituteArgument[index].empty()) {
3552 this->writeLine(writeback[index]->store(substituteArgument[index]));
3553 }
3554 }
3555
3556 // Return the result of invoking the function.
3557 return result;
3558 }
3559
assembleIndexExpression(const IndexExpression & i)3560 std::string WGSLCodeGenerator::assembleIndexExpression(const IndexExpression& i) {
3561 // Put the index value into a let-expression.
3562 std::string idx = this->writeNontrivialScratchLet(*i.index(), Precedence::kExpression);
3563 return this->assembleExpression(*i.base(), Precedence::kPostfix) + "[" + idx + "]";
3564 }
3565
assembleLiteral(const Literal & l)3566 std::string WGSLCodeGenerator::assembleLiteral(const Literal& l) {
3567 const Type& type = l.type();
3568 if (type.isFloat() || type.isBoolean()) {
3569 return l.description(OperatorPrecedence::kExpression);
3570 }
3571 SkASSERT(type.isInteger());
3572 if (type.matches(*fContext.fTypes.fUInt)) {
3573 return std::to_string(l.intValue() & 0xffffffff) + "u";
3574 } else if (type.matches(*fContext.fTypes.fUShort)) {
3575 return std::to_string(l.intValue() & 0xffff) + "u";
3576 } else {
3577 return std::to_string(l.intValue());
3578 }
3579 }
3580
assembleIncrementExpr(const Type & type)3581 std::string WGSLCodeGenerator::assembleIncrementExpr(const Type& type) {
3582 // `type(`
3583 std::string expr = to_wgsl_type(fContext, type);
3584 expr.push_back('(');
3585
3586 // `1, 1, 1...)`
3587 auto separator = SkSL::String::Separator();
3588 for (int slots = type.slotCount(); slots > 0; --slots) {
3589 expr += separator();
3590 expr += "1";
3591 }
3592 expr.push_back(')');
3593 return expr;
3594 }
3595
assemblePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)3596 std::string WGSLCodeGenerator::assemblePrefixExpression(const PrefixExpression& p,
3597 Precedence parentPrecedence) {
3598 // Unary + does nothing, so we omit it from the output.
3599 Operator op = p.getOperator();
3600 if (op.kind() == Operator::Kind::PLUS) {
3601 return this->assembleExpression(*p.operand(), Precedence::kPrefix);
3602 }
3603
3604 // Pre-increment/decrement expressions have no direct equivalent in WGSL.
3605 if (op.kind() == Operator::Kind::PLUSPLUS || op.kind() == Operator::Kind::MINUSMINUS) {
3606 std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3607 if (!lvalue) {
3608 return "";
3609 }
3610
3611 // Generate the new value: `lvalue + type(1, 1, 1...)`.
3612 std::string newValue =
3613 lvalue->load() +
3614 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3615 this->assembleIncrementExpr(p.operand()->type());
3616 this->writeLine(lvalue->store(newValue));
3617 return lvalue->load();
3618 }
3619
3620 // WGSL natively supports unary negation/not expressions (!,~,-).
3621 SkASSERT(op.kind() == OperatorKind::LOGICALNOT ||
3622 op.kind() == OperatorKind::BITWISENOT ||
3623 op.kind() == OperatorKind::MINUS);
3624
3625 // The unary negation operator only applies to scalars and vectors. For other mathematical
3626 // objects (such as matrices) we can express it as a multiplication by -1.
3627 std::string expr;
3628 const bool needsNegation = op.kind() == Operator::Kind::MINUS &&
3629 !p.operand()->type().isScalar() && !p.operand()->type().isVector();
3630 const bool needParens = needsNegation || Precedence::kPrefix >= parentPrecedence;
3631
3632 if (needParens) {
3633 expr.push_back('(');
3634 }
3635
3636 if (needsNegation) {
3637 expr += "-1.0 * ";
3638 expr += this->assembleExpression(*p.operand(), Precedence::kMultiplicative);
3639 } else {
3640 expr += p.getOperator().tightOperatorName();
3641 expr += this->assembleExpression(*p.operand(), Precedence::kPrefix);
3642 }
3643
3644 if (needParens) {
3645 expr.push_back(')');
3646 }
3647
3648 return expr;
3649 }
3650
assemblePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)3651 std::string WGSLCodeGenerator::assemblePostfixExpression(const PostfixExpression& p,
3652 Precedence parentPrecedence) {
3653 SkASSERT(p.getOperator().kind() == Operator::Kind::PLUSPLUS ||
3654 p.getOperator().kind() == Operator::Kind::MINUSMINUS);
3655
3656 // Post-increment/decrement expressions have no direct equivalent in WGSL; they do exist as a
3657 // standalone statement for convenience, but these aren't the same as SkSL's post-increments.
3658 std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3659 if (!lvalue) {
3660 return "";
3661 }
3662
3663 // If the expression is used, create a let-copy of the original value.
3664 // (At statement-level precedence, we know the value is unused and can skip this let-copy.)
3665 std::string originalValue;
3666 if (parentPrecedence != Precedence::kStatement) {
3667 originalValue = this->writeScratchLet(lvalue->load());
3668 }
3669 // Generate the new value: `lvalue + type(1, 1, 1...)`.
3670 std::string newValue = lvalue->load() +
3671 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3672 this->assembleIncrementExpr(p.operand()->type());
3673 this->writeLine(lvalue->store(newValue));
3674
3675 return originalValue;
3676 }
3677
assembleSwizzle(const Swizzle & swizzle)3678 std::string WGSLCodeGenerator::assembleSwizzle(const Swizzle& swizzle) {
3679 return this->assembleExpression(*swizzle.base(), Precedence::kPostfix) + "." +
3680 Swizzle::MaskString(swizzle.components());
3681 }
3682
writeScratchVar(const Type & type,const std::string & value)3683 std::string WGSLCodeGenerator::writeScratchVar(const Type& type, const std::string& value) {
3684 std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3685 this->write("var ");
3686 this->write(scratchVarName);
3687 this->write(": ");
3688 this->write(to_wgsl_type(fContext, type));
3689 if (!value.empty()) {
3690 this->write(" = ");
3691 this->write(value);
3692 }
3693 this->writeLine(";");
3694 return scratchVarName;
3695 }
3696
writeScratchLet(const std::string & expr,bool isCompileTimeConstant)3697 std::string WGSLCodeGenerator::writeScratchLet(const std::string& expr,
3698 bool isCompileTimeConstant) {
3699 std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3700 this->write(fAtFunctionScope && !isCompileTimeConstant ? "let " : "const ");
3701 this->write(scratchVarName);
3702 this->write(" = ");
3703 this->write(expr);
3704 this->writeLine(";");
3705 return scratchVarName;
3706 }
3707
writeScratchLet(const Expression & expr,Precedence parentPrecedence)3708 std::string WGSLCodeGenerator::writeScratchLet(const Expression& expr,
3709 Precedence parentPrecedence) {
3710 return this->writeScratchLet(this->assembleExpression(expr, parentPrecedence));
3711 }
3712
writeNontrivialScratchLet(const Expression & expr,Precedence parentPrecedence)3713 std::string WGSLCodeGenerator::writeNontrivialScratchLet(const Expression& expr,
3714 Precedence parentPrecedence) {
3715 std::string result = this->assembleExpression(expr, parentPrecedence);
3716 return is_nontrivial_expression(expr)
3717 ? this->writeScratchLet(result, Analysis::IsCompileTimeConstant(expr))
3718 : result;
3719 }
3720
assembleTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)3721 std::string WGSLCodeGenerator::assembleTernaryExpression(const TernaryExpression& t,
3722 Precedence parentPrecedence) {
3723 std::string expr;
3724
3725 // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
3726 // type. This can be represented with a call to the WGSL `select` intrinsic. Select doesn't
3727 // support short-circuiting, so we should only use it when both the true- and false-expressions
3728 // are trivial to evaluate.
3729 if ((t.type().isScalar() || t.type().isVector()) &&
3730 !Analysis::HasSideEffects(*t.test()) &&
3731 Analysis::IsTrivialExpression(*t.ifTrue()) &&
3732 Analysis::IsTrivialExpression(*t.ifFalse())) {
3733
3734 bool needParens = Precedence::kTernary >= parentPrecedence;
3735 if (needParens) {
3736 expr.push_back('(');
3737 }
3738 expr += "select(";
3739 expr += this->assembleExpression(*t.ifFalse(), Precedence::kSequence);
3740 expr += ", ";
3741 expr += this->assembleExpression(*t.ifTrue(), Precedence::kSequence);
3742 expr += ", ";
3743
3744 bool isVector = t.type().isVector();
3745 if (isVector) {
3746 // Splat the condition expression into a vector.
3747 expr += String::printf("vec%d<bool>(", t.type().columns());
3748 }
3749 expr += this->assembleExpression(*t.test(), Precedence::kSequence);
3750 if (isVector) {
3751 expr.push_back(')');
3752 }
3753 expr.push_back(')');
3754 if (needParens) {
3755 expr.push_back(')');
3756 }
3757 } else {
3758 // WGSL does not support ternary expressions. Instead, we hoist the expression out into the
3759 // surrounding block, convert it into an if statement, and write the result to a synthesized
3760 // variable. Instead of the original expression, we return that variable.
3761 expr = this->writeScratchVar(t.ifTrue()->type());
3762
3763 std::string testExpr = this->assembleExpression(*t.test(), Precedence::kExpression);
3764 this->write("if ");
3765 this->write(testExpr);
3766 this->writeLine(" {");
3767
3768 ++fIndentation;
3769 std::string trueExpr = this->assembleExpression(*t.ifTrue(), Precedence::kAssignment);
3770 this->write(expr);
3771 this->write(" = ");
3772 this->write(trueExpr);
3773 this->writeLine(";");
3774 --fIndentation;
3775
3776 this->writeLine("} else {");
3777
3778 ++fIndentation;
3779 std::string falseExpr = this->assembleExpression(*t.ifFalse(), Precedence::kAssignment);
3780 this->write(expr);
3781 this->write(" = ");
3782 this->write(falseExpr);
3783 this->writeLine(";");
3784 --fIndentation;
3785
3786 this->writeLine("}");
3787 }
3788 return expr;
3789 }
3790
variablePrefix(const Variable & v)3791 std::string WGSLCodeGenerator::variablePrefix(const Variable& v) {
3792 if (v.storage() == Variable::Storage::kGlobal) {
3793 // If the field refers to a pipeline IO parameter, then we access it via the synthesized IO
3794 // structs. We make an explicit exception for `sk_PointSize` which we declare as a
3795 // placeholder variable in global scope as it is not supported by WebGPU as a pipeline IO
3796 // parameter (see comments in `writeStageOutputStruct`).
3797 if (v.modifierFlags() & ModifierFlag::kIn) {
3798 return "_stageIn.";
3799 }
3800 if (v.modifierFlags() & ModifierFlag::kOut) {
3801 return "(*_stageOut).";
3802 }
3803
3804 // If the field refers to an anonymous-interface-block structure, access it via the
3805 // synthesized `_uniform0` or `_storage1` global.
3806 if (const InterfaceBlock* ib = v.interfaceBlock()) {
3807 const Type& ibType = ib->var()->type().componentType();
3808 if (const std::string* ibName = fInterfaceBlockNameMap.find(&ibType)) {
3809 return *ibName + '.';
3810 }
3811 }
3812
3813 // If the field refers to an top-level uniform, access it via the synthesized
3814 // `_globalUniforms` global. (Note that this should only occur in test code; Skia will
3815 // always put uniforms in an interface block.)
3816 if (is_in_global_uniforms(v)) {
3817 return "_globalUniforms.";
3818 }
3819 }
3820
3821 return "";
3822 }
3823
variableReferenceNameForLValue(const VariableReference & r)3824 std::string WGSLCodeGenerator::variableReferenceNameForLValue(const VariableReference& r) {
3825 const Variable& v = *r.variable();
3826
3827 if (v.storage() == Variable::Storage::kParameter &&
3828 (v.modifierFlags() & ModifierFlag::kOut || v.type().isUnsizedArray())) {
3829 // This is an out-parameter or unsized array parameter; it's pointer-typed, so we need to
3830 // dereference it. We wrap the dereference in parentheses, in case the value is used in an
3831 // access expression later.
3832 return "(*" + this->assembleName(v.mangledName()) + ')';
3833 }
3834
3835 return this->variablePrefix(v) + this->assembleName(v.mangledName());
3836 }
3837
assembleVariableReference(const VariableReference & r)3838 std::string WGSLCodeGenerator::assembleVariableReference(const VariableReference& r) {
3839 // TODO(b/294274678): Correctly handle RTFlip for built-ins.
3840 const Variable& v = *r.variable();
3841
3842 // Insert a conversion expression if this is a built-in variable whose type differs from the
3843 // SkSL.
3844 std::string expr;
3845 std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
3846 if (conversion.has_value()) {
3847 expr += *conversion;
3848 expr.push_back('(');
3849 }
3850
3851 expr += this->variableReferenceNameForLValue(r);
3852
3853 if (conversion.has_value()) {
3854 expr.push_back(')');
3855 }
3856
3857 return expr;
3858 }
3859
assembleAnyConstructor(const AnyConstructor & c)3860 std::string WGSLCodeGenerator::assembleAnyConstructor(const AnyConstructor& c) {
3861 std::string expr = to_wgsl_type(fContext, c.type());
3862 expr.push_back('(');
3863 auto separator = SkSL::String::Separator();
3864 for (const auto& e : c.argumentSpan()) {
3865 expr += separator();
3866 expr += this->assembleExpression(*e, Precedence::kSequence);
3867 }
3868 expr.push_back(')');
3869 return expr;
3870 }
3871
assembleConstructorCompound(const ConstructorCompound & c)3872 std::string WGSLCodeGenerator::assembleConstructorCompound(const ConstructorCompound& c) {
3873 if (c.type().isVector()) {
3874 return this->assembleConstructorCompoundVector(c);
3875 } else if (c.type().isMatrix()) {
3876 return this->assembleConstructorCompoundMatrix(c);
3877 } else {
3878 fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
3879 return {};
3880 }
3881 }
3882
assembleConstructorCompoundVector(const ConstructorCompound & c)3883 std::string WGSLCodeGenerator::assembleConstructorCompoundVector(const ConstructorCompound& c) {
3884 // WGSL supports constructing vectors from a mix of scalars and vectors but
3885 // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
3886 //
3887 // SkSL supports vec4(mat2x2) which we handle specially.
3888 if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
3889 const Expression& arg = *c.argumentSpan().front();
3890 if (arg.type().isMatrix()) {
3891 SkASSERT(arg.type().columns() == 2);
3892 SkASSERT(arg.type().rows() == 2);
3893
3894 std::string matrix = this->writeNontrivialScratchLet(arg, Precedence::kPostfix);
3895 return String::printf("%s(%s[0], %s[1])", to_wgsl_type(fContext, c.type()).c_str(),
3896 matrix.c_str(),
3897 matrix.c_str());
3898 }
3899 }
3900 return this->assembleAnyConstructor(c);
3901 }
3902
assembleConstructorCompoundMatrix(const ConstructorCompound & ctor)3903 std::string WGSLCodeGenerator::assembleConstructorCompoundMatrix(const ConstructorCompound& ctor) {
3904 SkASSERT(ctor.type().isMatrix());
3905
3906 std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3907 auto separator = String::Separator();
3908 for (const std::unique_ptr<Expression>& arg : ctor.arguments()) {
3909 SkASSERT(arg->type().isScalar() || arg->type().isVector());
3910
3911 if (arg->type().isScalar()) {
3912 expr += separator();
3913 expr += this->assembleExpression(*arg, Precedence::kSequence);
3914 } else {
3915 std::string inner = this->writeNontrivialScratchLet(*arg, Precedence::kSequence);
3916 int numSlots = arg->type().slotCount();
3917 for (int slot = 0; slot < numSlots; ++slot) {
3918 String::appendf(&expr, "%s%s[%d]", separator().c_str(), inner.c_str(), slot);
3919 }
3920 }
3921 }
3922 return expr + ')';
3923 }
3924
assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c)3925 std::string WGSLCodeGenerator::assembleConstructorDiagonalMatrix(
3926 const ConstructorDiagonalMatrix& c) {
3927 const Type& type = c.type();
3928 SkASSERT(type.isMatrix());
3929 SkASSERT(c.argument()->type().isScalar());
3930
3931 // Evaluate the inner-expression, creating a scratch variable if necessary.
3932 std::string inner = this->writeNontrivialScratchLet(*c.argument(), Precedence::kAssignment);
3933
3934 // Assemble a diagonal-matrix expression.
3935 std::string expr = to_wgsl_type(fContext, type) + '(';
3936 auto separator = String::Separator();
3937 for (int col = 0; col < type.columns(); ++col) {
3938 for (int row = 0; row < type.rows(); ++row) {
3939 expr += separator();
3940 if (col == row) {
3941 expr += inner;
3942 } else {
3943 expr += "0.0";
3944 }
3945 }
3946 }
3947 return expr + ')';
3948 }
3949
assembleConstructorMatrixResize(const ConstructorMatrixResize & ctor)3950 std::string WGSLCodeGenerator::assembleConstructorMatrixResize(
3951 const ConstructorMatrixResize& ctor) {
3952 std::string source = this->writeNontrivialScratchLet(*ctor.argument(), Precedence::kSequence);
3953 int columns = ctor.type().columns();
3954 int rows = ctor.type().rows();
3955 int sourceColumns = ctor.argument()->type().columns();
3956 int sourceRows = ctor.argument()->type().rows();
3957 auto separator = String::Separator();
3958 std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3959
3960 for (int c = 0; c < columns; ++c) {
3961 for (int r = 0; r < rows; ++r) {
3962 expr += separator();
3963 if (c < sourceColumns && r < sourceRows) {
3964 String::appendf(&expr, "%s[%d][%d]", source.c_str(), c, r);
3965 } else if (r == c) {
3966 expr += "1.0";
3967 } else {
3968 expr += "0.0";
3969 }
3970 }
3971 }
3972
3973 return expr + ')';
3974 }
3975
assembleEqualityExpression(const Type & left,const std::string & leftName,const Type & right,const std::string & rightName,Operator op,Precedence parentPrecedence)3976 std::string WGSLCodeGenerator::assembleEqualityExpression(const Type& left,
3977 const std::string& leftName,
3978 const Type& right,
3979 const std::string& rightName,
3980 Operator op,
3981 Precedence parentPrecedence) {
3982 SkASSERT(op.kind() == OperatorKind::EQEQ || op.kind() == OperatorKind::NEQ);
3983
3984 std::string expr;
3985 bool isEqual = (op.kind() == Operator::Kind::EQEQ);
3986 const char* const combiner = isEqual ? " && " : " || ";
3987
3988 if (left.isMatrix()) {
3989 // Each matrix column must be compared as if it were an individual vector.
3990 SkASSERT(right.isMatrix());
3991 SkASSERT(left.rows() == right.rows());
3992 SkASSERT(left.columns() == right.columns());
3993 int columns = left.columns();
3994 const Type& vecType = left.columnType(fContext);
3995 const char* separator = "(";
3996 for (int index = 0; index < columns; ++index) {
3997 expr += separator;
3998 std::string suffix = '[' + std::to_string(index) + ']';
3999 expr += this->assembleEqualityExpression(vecType, leftName + suffix,
4000 vecType, rightName + suffix,
4001 op, Precedence::kParentheses);
4002 separator = combiner;
4003 }
4004 return expr + ')';
4005 }
4006
4007 if (left.isArray()) {
4008 SkASSERT(right.matches(left));
4009 const Type& indexedType = left.componentType();
4010 const char* separator = "(";
4011 for (int index = 0; index < left.columns(); ++index) {
4012 expr += separator;
4013 std::string suffix = '[' + std::to_string(index) + ']';
4014 expr += this->assembleEqualityExpression(indexedType, leftName + suffix,
4015 indexedType, rightName + suffix,
4016 op, Precedence::kParentheses);
4017 separator = combiner;
4018 }
4019 return expr + ')';
4020 }
4021
4022 if (left.isStruct()) {
4023 // Recursively compare every field in the struct.
4024 SkASSERT(right.matches(left));
4025 SkSpan<const Field> fields = left.fields();
4026
4027 const char* separator = "(";
4028 for (const Field& field : fields) {
4029 expr += separator;
4030 expr += this->assembleEqualityExpression(
4031 *field.fType, leftName + '.' + this->assembleName(field.fName),
4032 *field.fType, rightName + '.' + this->assembleName(field.fName),
4033 op, Precedence::kParentheses);
4034 separator = combiner;
4035 }
4036 return expr + ')';
4037 }
4038
4039 if (left.isVector()) {
4040 // Compare vectors via `all(x == y)` or `any(x != y)`.
4041 SkASSERT(right.isVector());
4042 SkASSERT(left.slotCount() == right.slotCount());
4043
4044 expr += isEqual ? "all(" : "any(";
4045 expr += leftName;
4046 expr += operator_name(op);
4047 expr += rightName;
4048 return expr + ')';
4049 }
4050
4051 // Compare scalars via `x == y`.
4052 SkASSERT(right.isScalar());
4053 if (parentPrecedence < Precedence::kSequence) {
4054 expr = '(';
4055 }
4056 expr += leftName;
4057 expr += operator_name(op);
4058 expr += rightName;
4059 if (parentPrecedence < Precedence::kSequence) {
4060 expr += ')';
4061 }
4062 return expr;
4063 }
4064
assembleEqualityExpression(const Expression & left,const Expression & right,Operator op,Precedence parentPrecedence)4065 std::string WGSLCodeGenerator::assembleEqualityExpression(const Expression& left,
4066 const Expression& right,
4067 Operator op,
4068 Precedence parentPrecedence) {
4069 std::string leftName, rightName;
4070 if (left.type().isScalar() || left.type().isVector()) {
4071 // WGSL supports scalar and vector comparisons natively. We know the expressions will only
4072 // be emitted once, so there isn't a benefit to creating a let-declaration.
4073 leftName = this->assembleExpression(left, Precedence::kParentheses);
4074 rightName = this->assembleExpression(right, Precedence::kParentheses);
4075 } else {
4076 leftName = this->writeNontrivialScratchLet(left, Precedence::kAssignment);
4077 rightName = this->writeNontrivialScratchLet(right, Precedence::kAssignment);
4078 }
4079 return this->assembleEqualityExpression(left.type(), leftName, right.type(), rightName,
4080 op, parentPrecedence);
4081 }
4082
writeProgramElement(const ProgramElement & e)4083 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
4084 switch (e.kind()) {
4085 case ProgramElement::Kind::kExtension:
4086 // TODO(skia:13092): WGSL supports extensions via the "enable" directive
4087 // (https://www.w3.org/TR/WGSL/#enable-extensions-sec ). While we could easily emit this
4088 // directive, we should first ensure that all possible SkSL extension names are
4089 // converted to their appropriate WGSL extension.
4090 break;
4091 case ProgramElement::Kind::kGlobalVar:
4092 this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
4093 break;
4094 case ProgramElement::Kind::kInterfaceBlock:
4095 // All interface block declarations are handled explicitly as the "program header" in
4096 // generateCode().
4097 break;
4098 case ProgramElement::Kind::kStructDefinition:
4099 this->writeStructDefinition(e.as<StructDefinition>());
4100 break;
4101 case ProgramElement::Kind::kFunctionPrototype:
4102 // A WGSL function declaration must contain its body and the function name is in scope
4103 // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
4104 // https://www.w3.org/TR/WGSL/#declaration-and-scope).
4105 //
4106 // As such, we don't emit function prototypes.
4107 break;
4108 case ProgramElement::Kind::kFunction:
4109 this->writeFunction(e.as<FunctionDefinition>());
4110 break;
4111 case ProgramElement::Kind::kModifiers:
4112 this->writeModifiersDeclaration(e.as<ModifiersDeclaration>());
4113 break;
4114 default:
4115 SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
4116 break;
4117 }
4118 }
4119
writeTextureOrSampler(const Variable & var,int bindingLocation,std::string_view suffix,std::string_view wgslType)4120 void WGSLCodeGenerator::writeTextureOrSampler(const Variable& var,
4121 int bindingLocation,
4122 std::string_view suffix,
4123 std::string_view wgslType) {
4124 if (var.type().dimensions() != SpvDim2D) {
4125 // Skia currently only uses 2D textures.
4126 fContext.fErrors->error(var.varDeclaration()->position(), "unsupported texture dimensions");
4127 return;
4128 }
4129
4130 this->write("@group(");
4131 this->write(std::to_string(std::max(0, var.layout().fSet)));
4132 this->write(") @binding(");
4133 this->write(std::to_string(bindingLocation));
4134 this->write(") var ");
4135 this->write(this->assembleName(var.mangledName()));
4136 this->write(suffix);
4137 this->write(": ");
4138 this->write(wgslType);
4139 this->writeLine(";");
4140 }
4141
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)4142 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
4143 const VarDeclaration& decl = d.varDeclaration();
4144 const Variable& var = *decl.var();
4145 if ((var.modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) ||
4146 is_in_global_uniforms(var)) {
4147 // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
4148 // in generateCode().
4149 return;
4150 }
4151
4152 const Type::TypeKind varKind = var.type().typeKind();
4153 if (varKind == Type::TypeKind::kSampler) {
4154 // If the sampler binding was unassigned, provide a scratch value; this will make
4155 // golden-output tests pass, but will not actually be usable for drawing.
4156 int samplerLocation = var.layout().fSampler >= 0 ? var.layout().fSampler
4157 : 10000 + fScratchCount++;
4158 this->writeTextureOrSampler(var, samplerLocation, kSamplerSuffix, "sampler");
4159
4160 // If the texture binding was unassigned, provide a scratch value (for golden-output tests).
4161 int textureLocation = var.layout().fTexture >= 0 ? var.layout().fTexture
4162 : 10000 + fScratchCount++;
4163 this->writeTextureOrSampler(var, textureLocation, kTextureSuffix, "texture_2d<f32>");
4164 return;
4165 }
4166
4167 if (varKind == Type::TypeKind::kTexture) {
4168 // If a binding location was unassigned, provide a scratch value (for golden-output tests).
4169 int textureLocation = var.layout().fBinding >= 0 ? var.layout().fBinding
4170 : 10000 + fScratchCount++;
4171 // For a texture without an associated sampler, we don't apply a suffix.
4172 this->writeTextureOrSampler(var, textureLocation, /*suffix=*/"",
4173 to_wgsl_type(fContext, var.type(), &var.layout()));
4174 return;
4175 }
4176
4177 std::string initializer;
4178 if (decl.value()) {
4179 // We assume here that the initial-value expression will not emit any helper statements.
4180 // Initial-value expressions are required to pass IsConstantExpression, which limits the
4181 // blast radius to constructors, literals, and other constant values/variables.
4182 initializer += " = ";
4183 initializer += this->assembleExpression(*decl.value(), Precedence::kAssignment);
4184 }
4185
4186 if (var.modifierFlags().isConst()) {
4187 this->write("const ");
4188 } else if (var.modifierFlags().isWorkgroup()) {
4189 this->write("var<workgroup> ");
4190 } else if (var.modifierFlags().isPixelLocal()) {
4191 this->write("var<pixel_local> ");
4192 } else {
4193 this->write("var<private> ");
4194 }
4195 this->write(this->assembleName(var.mangledName()));
4196 this->write(": " + to_wgsl_type(fContext, var.type(), &var.layout()));
4197 this->write(initializer);
4198 this->writeLine(";");
4199 }
4200
writeStructDefinition(const StructDefinition & s)4201 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
4202 const Type& type = s.type();
4203 this->writeLine("struct " + type.displayName() + " {");
4204 this->writeFields(type.fields(), /*memoryLayout=*/nullptr);
4205 this->writeLine("};");
4206 }
4207
writeModifiersDeclaration(const ModifiersDeclaration & modifiers)4208 void WGSLCodeGenerator::writeModifiersDeclaration(const ModifiersDeclaration& modifiers) {
4209 LayoutFlags flags = modifiers.layout().fFlags;
4210 flags &= ~(LayoutFlag::kLocalSizeX | LayoutFlag::kLocalSizeY | LayoutFlag::kLocalSizeZ);
4211 if (flags != LayoutFlag::kNone) {
4212 fContext.fErrors->error(modifiers.position(), "unsupported declaration");
4213 return;
4214 }
4215
4216 if (modifiers.layout().fLocalSizeX >= 0) {
4217 fLocalSizeX = modifiers.layout().fLocalSizeX;
4218 }
4219 if (modifiers.layout().fLocalSizeY >= 0) {
4220 fLocalSizeY = modifiers.layout().fLocalSizeY;
4221 }
4222 if (modifiers.layout().fLocalSizeZ >= 0) {
4223 fLocalSizeZ = modifiers.layout().fLocalSizeZ;
4224 }
4225 }
4226
writeFields(SkSpan<const Field> fields,const MemoryLayout * memoryLayout)4227 void WGSLCodeGenerator::writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout) {
4228 fIndentation++;
4229
4230 // TODO(skia:14370): array uniforms may need manual fixup for std140 padding. (Those uniforms
4231 // will also need special handling when they are accessed, or passed to functions.)
4232 for (size_t index = 0; index < fields.size(); ++index) {
4233 const Field& field = fields[index];
4234 if (memoryLayout && !memoryLayout->isSupported(*field.fType)) {
4235 // Reject types that aren't supported by the memory layout.
4236 fContext.fErrors->error(field.fPosition, "type '" + std::string(field.fType->name()) +
4237 "' is not permitted here");
4238 return;
4239 }
4240
4241 // Prepend @size(n) to enforce the offsets from the SkSL layout. (This is effectively
4242 // a gadget that we can use to insert padding between elements.)
4243 if (index < fields.size() - 1) {
4244 int thisFieldOffset = field.fLayout.fOffset;
4245 int nextFieldOffset = fields[index + 1].fLayout.fOffset;
4246 if (index == 0 && thisFieldOffset > 0) {
4247 fContext.fErrors->error(field.fPosition, "field must have an offset of zero");
4248 return;
4249 }
4250 if (thisFieldOffset >= 0 && nextFieldOffset > thisFieldOffset) {
4251 this->write("@size(");
4252 this->write(std::to_string(nextFieldOffset - thisFieldOffset));
4253 this->write(") ");
4254 }
4255 }
4256
4257 this->write(this->assembleName(field.fName));
4258 this->write(": ");
4259 if (const FieldPolyfillInfo* info = fFieldPolyfillMap.find(&field)) {
4260 if (info->fIsArray) {
4261 // This properly handles arrays of matrices, as well as arrays of other primitives.
4262 SkASSERT(field.fType->isArray());
4263 this->write("array<_skArrayElement_");
4264 this->write(field.fType->abbreviatedName());
4265 this->write(", ");
4266 this->write(std::to_string(field.fType->columns()));
4267 this->write(">");
4268 } else if (info->fIsMatrix) {
4269 this->write("_skMatrix");
4270 this->write(std::to_string(field.fType->columns()));
4271 this->write(std::to_string(field.fType->rows()));
4272 } else {
4273 SkDEBUGFAILF("need polyfill for %s", info->fReplacementName.c_str());
4274 }
4275 } else {
4276 this->write(to_wgsl_type(fContext, *field.fType, &field.fLayout));
4277 }
4278 this->writeLine(",");
4279 }
4280
4281 fIndentation--;
4282 }
4283
writeEnables()4284 void WGSLCodeGenerator::writeEnables() {
4285 this->writeLine("diagnostic(off, derivative_uniformity);");
4286 this->writeLine("diagnostic(off, chromium.unreachable_code);");
4287
4288 if (fRequirements.fPixelLocalExtension) {
4289 this->writeLine("enable chromium_experimental_pixel_local;");
4290 }
4291 if (fProgram.fInterface.fUseLastFragColor) {
4292 this->writeLine("enable chromium_experimental_framebuffer_fetch;");
4293 }
4294 if (fProgram.fInterface.fOutputSecondaryColor) {
4295 this->writeLine("enable dual_source_blending;");
4296 }
4297 }
4298
needsStageInputStruct() const4299 bool WGSLCodeGenerator::needsStageInputStruct() const {
4300 // It is illegal to declare a struct with no members; we can't emit a placeholder empty stage
4301 // input struct.
4302 return !fPipelineInputs.empty();
4303 }
4304
writeStageInputStruct()4305 void WGSLCodeGenerator::writeStageInputStruct() {
4306 if (!this->needsStageInputStruct()) {
4307 return;
4308 }
4309
4310 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4311 SkASSERT(!structNamePrefix.empty());
4312
4313 this->write("struct ");
4314 this->write(structNamePrefix);
4315 this->writeLine("In {");
4316 fIndentation++;
4317
4318 for (const Variable* v : fPipelineInputs) {
4319 if (v->type().isInterfaceBlock()) {
4320 for (const Field& f : v->type().fields()) {
4321 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fModifierFlags, f.fName,
4322 Delimiter::kComma);
4323 }
4324 } else {
4325 this->writePipelineIODeclaration(v->layout(), v->type(), v->modifierFlags(),
4326 v->mangledName(), Delimiter::kComma);
4327 }
4328 }
4329
4330 fIndentation--;
4331 this->writeLine("};");
4332 }
4333
needsStageOutputStruct() const4334 bool WGSLCodeGenerator::needsStageOutputStruct() const {
4335 // It is illegal to declare a struct with no members. However, vertex programs will _always_
4336 // have an output stage in WGSL, because the spec requires them to emit `@builtin(position)`.
4337 // So we always synthesize a reference to `sk_Position` even if the program doesn't need it.
4338 return !fPipelineOutputs.empty() || ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4339 }
4340
writeStageOutputStruct()4341 void WGSLCodeGenerator::writeStageOutputStruct() {
4342 if (!this->needsStageOutputStruct()) {
4343 return;
4344 }
4345
4346 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4347 SkASSERT(!structNamePrefix.empty());
4348
4349 this->write("struct ");
4350 this->write(structNamePrefix);
4351 this->writeLine("Out {");
4352 fIndentation++;
4353
4354 bool declaredPositionBuiltin = false;
4355 bool requiresPointSizeBuiltin = false;
4356 for (const Variable* v : fPipelineOutputs) {
4357 if (v->type().isInterfaceBlock()) {
4358 for (const auto& f : v->type().fields()) {
4359 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fModifierFlags, f.fName,
4360 Delimiter::kComma);
4361 if (f.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
4362 declaredPositionBuiltin = true;
4363 } else if (f.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
4364 // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
4365 // writePipelineIODeclaration will never write it. We mark it here if the
4366 // declaration is needed so we can synthesize it below.
4367 requiresPointSizeBuiltin = true;
4368 }
4369 }
4370 } else {
4371 this->writePipelineIODeclaration(v->layout(), v->type(), v->modifierFlags(),
4372 v->mangledName(), Delimiter::kComma);
4373 }
4374 }
4375
4376 // A vertex program must include the `position` builtin in its entrypoint's return type.
4377 const bool positionBuiltinRequired = ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4378 if (positionBuiltinRequired && !declaredPositionBuiltin) {
4379 this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
4380 }
4381
4382 fIndentation--;
4383 this->writeLine("};");
4384
4385 // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
4386 // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
4387 //
4388 // There isn't anything we can do to emulate this correctly at this stage so we synthesize a
4389 // placeholder global variable that has no effect. Programs should not rely on sk_PointSize when
4390 // using the Dawn backend.
4391 if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
4392 this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
4393 }
4394 }
4395
prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock * interfaceBlock,std::string_view instanceName,MemoryLayout::Standard nativeLayout)4396 void WGSLCodeGenerator::prepareUniformPolyfillsForInterfaceBlock(
4397 const InterfaceBlock* interfaceBlock,
4398 std::string_view instanceName,
4399 MemoryLayout::Standard nativeLayout) {
4400 SkSL::MemoryLayout std140(MemoryLayout::Standard::k140);
4401 SkSL::MemoryLayout native(nativeLayout);
4402
4403 const Type& structType = interfaceBlock->var()->type().componentType();
4404 for (const Field& field : structType.fields()) {
4405 const Type* type = field.fType;
4406 bool needsArrayPolyfill = false;
4407 bool needsMatrixPolyfill = false;
4408
4409 auto isPolyfillableMatrixType = [&](const Type* type) {
4410 return type->isMatrix() && std140.stride(*type) != native.stride(*type);
4411 };
4412
4413 if (isPolyfillableMatrixType(type)) {
4414 // Matrices will be represented as 16-byte aligned arrays in std140, and reconstituted
4415 // into proper matrices as they are later accessed. We need to synthesize polyfill.
4416 needsMatrixPolyfill = true;
4417 } else if (type->isArray() && !type->isUnsizedArray() &&
4418 !type->componentType().isOpaque()) {
4419 const Type* innerType = &type->componentType();
4420 if (isPolyfillableMatrixType(innerType)) {
4421 // Use a polyfill when the array contains a matrix that requires polyfill.
4422 needsArrayPolyfill = true;
4423 needsMatrixPolyfill = true;
4424 } else if (native.size(*innerType) < 16) {
4425 // Use a polyfill when the array elements are smaller than 16 bytes, since std140
4426 // will pad elements to a 16-byte stride.
4427 needsArrayPolyfill = true;
4428 }
4429 }
4430
4431 if (needsArrayPolyfill || needsMatrixPolyfill) {
4432 // Add a polyfill for this matrix type.
4433 FieldPolyfillInfo info;
4434 info.fInterfaceBlock = interfaceBlock;
4435 info.fReplacementName = "_skUnpacked_" + std::string(instanceName) + '_' +
4436 this->assembleName(field.fName);
4437 info.fIsArray = needsArrayPolyfill;
4438 info.fIsMatrix = needsMatrixPolyfill;
4439 fFieldPolyfillMap.set(&field, info);
4440 }
4441 }
4442 }
4443
writeUniformsAndBuffers()4444 void WGSLCodeGenerator::writeUniformsAndBuffers() {
4445 for (const ProgramElement* e : fProgram.elements()) {
4446 // Iterate through the interface blocks.
4447 if (!e->is<InterfaceBlock>()) {
4448 continue;
4449 }
4450 const InterfaceBlock& ib = e->as<InterfaceBlock>();
4451
4452 // Determine if this interface block holds uniforms, buffers, or something else (skip it).
4453 std::string_view addressSpace;
4454 std::string_view accessMode;
4455 MemoryLayout::Standard nativeLayout;
4456 if (ib.var()->modifierFlags().isUniform()) {
4457 addressSpace = "uniform";
4458 nativeLayout = MemoryLayout::Standard::kWGSLUniform_Base;
4459 } else if (ib.var()->modifierFlags().isBuffer()) {
4460 addressSpace = "storage";
4461 nativeLayout = MemoryLayout::Standard::kWGSLStorage_Base;
4462 accessMode = ib.var()->modifierFlags().isReadOnly() ? ", read" : ", read_write";
4463 } else {
4464 continue;
4465 }
4466
4467 // If we have an anonymous interface block, assign a name like `_uniform0` or `_storage1`.
4468 std::string instanceName;
4469 if (ib.instanceName().empty()) {
4470 instanceName = "_" + std::string(addressSpace) + std::to_string(fScratchCount++);
4471 fInterfaceBlockNameMap[&ib.var()->type().componentType()] = instanceName;
4472 } else {
4473 instanceName = std::string(ib.instanceName());
4474 }
4475
4476 this->prepareUniformPolyfillsForInterfaceBlock(&ib, instanceName, nativeLayout);
4477
4478 // Create a struct to hold all of the fields from this InterfaceBlock.
4479 SkASSERT(!ib.typeName().empty());
4480 this->write("struct ");
4481 this->write(ib.typeName());
4482 this->writeLine(" {");
4483
4484 // Find the struct type and fields used by this interface block.
4485 const Type& ibType = ib.var()->type().componentType();
4486 SkASSERT(ibType.isStruct());
4487
4488 SkSpan<const Field> ibFields = ibType.fields();
4489 SkASSERT(!ibFields.empty());
4490
4491 MemoryLayout layout(MemoryLayout::Standard::k140);
4492 this->writeFields(ibFields, &layout);
4493 this->writeLine("};");
4494 this->write("@group(");
4495 this->write(std::to_string(std::max(0, ib.var()->layout().fSet)));
4496 this->write(") @binding(");
4497 this->write(std::to_string(std::max(0, ib.var()->layout().fBinding)));
4498 this->write(") var<");
4499 this->write(addressSpace);
4500 this->write(accessMode);
4501 this->write("> ");
4502 this->write(instanceName);
4503 this->write(" : ");
4504 this->write(to_wgsl_type(fContext, ib.var()->type(), &ib.var()->layout()));
4505 this->writeLine(";");
4506 }
4507 }
4508
writeNonBlockUniformsForTests()4509 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
4510 bool declaredUniformsStruct = false;
4511
4512 for (const ProgramElement* e : fProgram.elements()) {
4513 if (e->is<GlobalVarDeclaration>()) {
4514 const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
4515 const Variable& var = *decls.varDeclaration().var();
4516 if (is_in_global_uniforms(var)) {
4517 if (!declaredUniformsStruct) {
4518 this->write("struct _GlobalUniforms {\n");
4519 declaredUniformsStruct = true;
4520 }
4521 this->write(" ");
4522 this->writeVariableDecl(var.layout(), var.type(), var.mangledName(),
4523 Delimiter::kComma);
4524 }
4525 }
4526 }
4527 if (declaredUniformsStruct) {
4528 int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4529 int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
4530 this->write("};\n");
4531 this->write("@binding(" + std::to_string(binding) + ") ");
4532 this->write("@group(" + std::to_string(set) + ") ");
4533 this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
4534 }
4535 }
4536
functionDependencyArgs(const FunctionDeclaration & f)4537 std::string WGSLCodeGenerator::functionDependencyArgs(const FunctionDeclaration& f) {
4538 WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4539 std::string args;
4540 if (deps && *deps) {
4541 const char* separator = "";
4542 if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4543 args += "_stageIn";
4544 separator = ", ";
4545 }
4546 if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4547 args += separator;
4548 args += "_stageOut";
4549 }
4550 }
4551 return args;
4552 }
4553
writeFunctionDependencyParams(const FunctionDeclaration & f)4554 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
4555 WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4556 if (!deps || !*deps) {
4557 return false;
4558 }
4559
4560 std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4561 if (structNamePrefix.empty()) {
4562 return false;
4563 }
4564 const char* separator = "";
4565 if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4566 this->write("_stageIn: ");
4567 separator = ", ";
4568 this->write(structNamePrefix);
4569 this->write("In");
4570 }
4571 if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4572 this->write(separator);
4573 this->write("_stageOut: ptr<function, ");
4574 this->write(structNamePrefix);
4575 this->write("Out>");
4576 }
4577 return true;
4578 }
4579
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out,PrettyPrint pp,IncludeSyntheticCode isc,ValidateWGSLProc validateWGSL)4580 bool ToWGSL(Program& program,
4581 const ShaderCaps* caps,
4582 OutputStream& out,
4583 PrettyPrint pp,
4584 IncludeSyntheticCode isc,
4585 ValidateWGSLProc validateWGSL) {
4586 TRACE_EVENT0("skia.shaders", "SkSL::ToWGSL");
4587 SkASSERT(caps != nullptr);
4588
4589 program.fContext->fErrors->setSource(*program.fSource);
4590 bool result;
4591 if (validateWGSL) {
4592 StringStream wgsl;
4593 WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &wgsl, pp, isc);
4594 result = cg.generateCode();
4595 if (result) {
4596 std::string_view wgslBytes = wgsl.str();
4597 std::string warnings;
4598 result = validateWGSL(*program.fContext->fErrors, wgslBytes, &warnings);
4599 if (!warnings.empty()) {
4600 out.writeText("/* Tint reported warnings. */\n\n");
4601 }
4602 out.write(wgslBytes.data(), wgslBytes.size());
4603 }
4604 } else {
4605 WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &out, pp, isc);
4606 result = cg.generateCode();
4607 }
4608 program.fContext->fErrors->setSource(std::string_view());
4609
4610 return result;
4611 }
4612
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out)4613 bool ToWGSL(Program& program, const ShaderCaps* caps, OutputStream& out) {
4614 #if defined(SK_DEBUG)
4615 constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kYes;
4616 #else
4617 constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kNo;
4618 #endif
4619 return ToWGSL(program, caps, out, defaultPrintOpts, IncludeSyntheticCode::kNo, nullptr);
4620 }
4621
ToWGSL(Program & program,const ShaderCaps * caps,std::string * out)4622 bool ToWGSL(Program& program, const ShaderCaps* caps, std::string* out) {
4623 StringStream buffer;
4624 if (!ToWGSL(program, caps, buffer)) {
4625 return false;
4626 }
4627 *out = buffer.str();
4628 return true;
4629 }
4630
4631 } // namespace SkSL
4632