xref: /aosp_15_r20/external/skia/src/sksl/codegen/SkSLWGSLCodeGenerator.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2022 Google Inc.
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/codegen/SkSLWGSLCodeGenerator.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/core/SkTypes.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/private/base/SkTo.h"
14 #include "src/base/SkEnumBitMask.h"
15 #include "src/base/SkStringView.h"
16 #include "src/core/SkTHash.h"
17 #include "src/core/SkTraceEvent.h"
18 #include "src/sksl/SkSLAnalysis.h"
19 #include "src/sksl/SkSLBuiltinTypes.h"
20 #include "src/sksl/SkSLCompiler.h"
21 #include "src/sksl/SkSLConstantFolder.h"
22 #include "src/sksl/SkSLContext.h"
23 #include "src/sksl/SkSLDefines.h"
24 #include "src/sksl/SkSLErrorReporter.h"
25 #include "src/sksl/SkSLIntrinsicList.h"
26 #include "src/sksl/SkSLMemoryLayout.h"
27 #include "src/sksl/SkSLOperator.h"
28 #include "src/sksl/SkSLOutputStream.h"
29 #include "src/sksl/SkSLPosition.h"
30 #include "src/sksl/SkSLProgramSettings.h"
31 #include "src/sksl/SkSLString.h"
32 #include "src/sksl/SkSLStringStream.h"
33 #include "src/sksl/SkSLUtil.h"
34 #include "src/sksl/analysis/SkSLProgramUsage.h"
35 #include "src/sksl/analysis/SkSLProgramVisitor.h"
36 #include "src/sksl/codegen/SkSLCodeGenTypes.h"
37 #include "src/sksl/codegen/SkSLCodeGenerator.h"
38 #include "src/sksl/ir/SkSLBinaryExpression.h"
39 #include "src/sksl/ir/SkSLBlock.h"
40 #include "src/sksl/ir/SkSLConstructor.h"
41 #include "src/sksl/ir/SkSLConstructorArrayCast.h"
42 #include "src/sksl/ir/SkSLConstructorCompound.h"
43 #include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
44 #include "src/sksl/ir/SkSLConstructorMatrixResize.h"
45 #include "src/sksl/ir/SkSLDoStatement.h"
46 #include "src/sksl/ir/SkSLExpression.h"
47 #include "src/sksl/ir/SkSLExpressionStatement.h"
48 #include "src/sksl/ir/SkSLFieldAccess.h"
49 #include "src/sksl/ir/SkSLFieldSymbol.h"
50 #include "src/sksl/ir/SkSLForStatement.h"
51 #include "src/sksl/ir/SkSLFunctionCall.h"
52 #include "src/sksl/ir/SkSLFunctionDeclaration.h"
53 #include "src/sksl/ir/SkSLFunctionDefinition.h"
54 #include "src/sksl/ir/SkSLIRHelpers.h"
55 #include "src/sksl/ir/SkSLIRNode.h"
56 #include "src/sksl/ir/SkSLIfStatement.h"
57 #include "src/sksl/ir/SkSLIndexExpression.h"
58 #include "src/sksl/ir/SkSLInterfaceBlock.h"
59 #include "src/sksl/ir/SkSLLayout.h"
60 #include "src/sksl/ir/SkSLLiteral.h"
61 #include "src/sksl/ir/SkSLModifierFlags.h"
62 #include "src/sksl/ir/SkSLModifiersDeclaration.h"
63 #include "src/sksl/ir/SkSLPostfixExpression.h"
64 #include "src/sksl/ir/SkSLPrefixExpression.h"
65 #include "src/sksl/ir/SkSLProgram.h"
66 #include "src/sksl/ir/SkSLProgramElement.h"
67 #include "src/sksl/ir/SkSLReturnStatement.h"
68 #include "src/sksl/ir/SkSLSetting.h"
69 #include "src/sksl/ir/SkSLStatement.h"
70 #include "src/sksl/ir/SkSLStructDefinition.h"
71 #include "src/sksl/ir/SkSLSwitchCase.h"
72 #include "src/sksl/ir/SkSLSwitchStatement.h"
73 #include "src/sksl/ir/SkSLSwizzle.h"
74 #include "src/sksl/ir/SkSLTernaryExpression.h"
75 #include "src/sksl/ir/SkSLType.h"
76 #include "src/sksl/ir/SkSLVarDeclarations.h"
77 #include "src/sksl/ir/SkSLVariable.h"
78 #include "src/sksl/ir/SkSLVariableReference.h"
79 #include "src/sksl/spirv.h"
80 #include "src/sksl/transform/SkSLTransform.h"
81 
82 #include <algorithm>
83 #include <cstddef>
84 #include <cstdint>
85 #include <initializer_list>
86 #include <iterator>
87 #include <memory>
88 #include <optional>
89 #include <string>
90 #include <string_view>
91 #include <utility>
92 
93 using namespace skia_private;
94 
95 namespace {
96 
97 // Represents a function's dependencies that are not accessible in global scope. For instance,
98 // pipeline stage input and output parameters must be passed in as an argument.
99 //
100 // This is a bitmask enum. (It would be inside `class WGSLCodeGenerator`, but this leads to build
101 // errors in MSVC.)
102 enum class WGSLFunctionDependency : uint8_t {
103     kNone = 0,
104     kPipelineInputs  = 1 << 0,
105     kPipelineOutputs = 1 << 1,
106 };
107 using WGSLFunctionDependencies = SkEnumBitMask<WGSLFunctionDependency>;
108 
109 SK_MAKE_BITMASK_OPS(WGSLFunctionDependency)
110 
111 }  // namespace
112 
113 namespace SkSL {
114 
115 class WGSLCodeGenerator : public CodeGenerator {
116 public:
117     // See https://www.w3.org/TR/WGSL/#builtin-values
118     enum class Builtin {
119         // Vertex stage:
120         kVertexIndex,    // input
121         kInstanceIndex,  // input
122         kPosition,       // output, fragment stage input
123 
124         // Fragment stage:
125         kLastFragColor,  // input
126         kFrontFacing,    // input
127         kSampleIndex,    // input
128         kFragDepth,      // output
129         kSampleMaskIn,   // input
130         kSampleMask,     // output
131 
132         // Compute stage:
133         kLocalInvocationId,     // input
134         kLocalInvocationIndex,  // input
135         kGlobalInvocationId,    // input
136         kWorkgroupId,           // input
137         kNumWorkgroups,         // input
138     };
139 
140     // Variable declarations can be terminated by:
141     //   - comma (","), e.g. in struct member declarations or function parameters
142     //   - semicolon (";"), e.g. in function scope variables
143     // A "none" option is provided to skip the delimiter when not needed, e.g. at the end of a list
144     // of declarations.
145     enum class Delimiter {
146         kComma,
147         kSemicolon,
148         kNone,
149     };
150 
151     struct ProgramRequirements {
152         using DepsMap = skia_private::THashMap<const FunctionDeclaration*,
153                                                WGSLFunctionDependencies>;
154 
155         // Mappings used to synthesize function parameters according to dependencies on pipeline
156         // input/output variables.
157         DepsMap fDependencies;
158 
159         // These flags track extensions that will need to be enabled.
160         bool fPixelLocalExtension = false;
161     };
162 
WGSLCodeGenerator(const Context * context,const ShaderCaps * caps,const Program * program,OutputStream * out,PrettyPrint pp,IncludeSyntheticCode isc)163     WGSLCodeGenerator(const Context* context,
164                       const ShaderCaps* caps,
165                       const Program* program,
166                       OutputStream* out,
167                       PrettyPrint pp,
168                       IncludeSyntheticCode isc)
169             : CodeGenerator(context, caps, program, out)
170             , fPrettyPrint(pp)
171             , fGenSyntheticCode(isc) {}
172 
173     bool generateCode() override;
174 
175 private:
176     using Precedence = OperatorPrecedence;
177 
178     // Called by generateCode() as the first step.
179     void preprocessProgram();
180 
181     // Write output content while correctly handling indentation.
182     void write(std::string_view s);
183     void writeLine(std::string_view s = std::string_view());
184     void finishLine();
185 
186     // Helpers to declare a pipeline stage IO parameter declaration.
187     void writePipelineIODeclaration(const Layout& layout,
188                                     const Type& type,
189                                     ModifierFlags modifiers,
190                                     std::string_view name,
191                                     Delimiter delimiter);
192     void writeUserDefinedIODecl(const Layout& layout,
193                                 const Type& type,
194                                 ModifierFlags modifiers,
195                                 std::string_view name,
196                                 Delimiter delimiter);
197     void writeBuiltinIODecl(const Type& type,
198                             std::string_view name,
199                             Builtin builtin,
200                             Delimiter delimiter);
201     void writeVariableDecl(const Layout& layout,
202                            const Type& type,
203                            std::string_view name,
204                            Delimiter delimiter);
205 
206     // Write a function definition.
207     void writeFunction(const FunctionDefinition& f);
208     void writeFunctionDeclaration(const FunctionDeclaration& f,
209                                   SkSpan<const bool> paramNeedsDedicatedStorage);
210 
211     // Write the program entry point.
212     void writeEntryPoint(const FunctionDefinition& f);
213 
214     // Writers for supported statement types.
215     void writeStatement(const Statement& s);
216     void writeStatements(const StatementArray& statements);
217     void writeBlock(const Block& b);
218     void writeDoStatement(const DoStatement& expr);
219     void writeExpressionStatement(const Expression& expr);
220     void writeForStatement(const ForStatement& s);
221     void writeIfStatement(const IfStatement& s);
222     void writeReturnStatement(const ReturnStatement& s);
223     void writeSwitchStatement(const SwitchStatement& s);
224     void writeSwitchCases(SkSpan<const SwitchCase* const> cases);
225     void writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
226                                              std::string_view switchValue);
227     void writeSwitchCaseList(SkSpan<const SwitchCase* const> cases);
228     void writeVarDeclaration(const VarDeclaration& varDecl);
229 
230     // Synthesizes an LValue for an expression.
231     class LValue;
232     class PointerLValue;
233     class SwizzleLValue;
234     class VectorComponentLValue;
235     std::unique_ptr<LValue> makeLValue(const Expression& e);
236 
237     std::string variableReferenceNameForLValue(const VariableReference& r);
238     std::string variablePrefix(const Variable& v);
239 
240     bool binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left, const Type& right, Operator op);
241 
242     // Writers for expressions. These return the final expression text as a string, and emit any
243     // necessary setup code directly into the program as necessary. The returned expression may be
244     // a `let`-alias that cannot be assigned-into; use `makeLValue` for an assignable expression.
245     std::string assembleExpression(const Expression& e, Precedence parentPrecedence);
246     std::string assembleBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
247     std::string assembleBinaryExpression(const Expression& left,
248                                          Operator op,
249                                          const Expression& right,
250                                          const Type& resultType,
251                                          Precedence parentPrecedence);
252     std::string assembleFieldAccess(const FieldAccess& f);
253     std::string assembleFunctionCall(const FunctionCall& call, Precedence parentPrecedence);
254     std::string assembleIndexExpression(const IndexExpression& i);
255     std::string assembleLiteral(const Literal& l);
256     std::string assemblePostfixExpression(const PostfixExpression& p, Precedence parentPrecedence);
257     std::string assemblePrefixExpression(const PrefixExpression& p, Precedence parentPrecedence);
258     std::string assembleSwizzle(const Swizzle& swizzle);
259     std::string assembleTernaryExpression(const TernaryExpression& t, Precedence parentPrecedence);
260     std::string assembleVariableReference(const VariableReference& r);
261     std::string assembleName(std::string_view name);
262 
263     std::string assembleIncrementExpr(const Type& type);
264 
265     // Intrinsic helper functions.
266     std::string assembleIntrinsicCall(const FunctionCall& call,
267                                       IntrinsicKind kind,
268                                       Precedence parentPrecedence);
269     std::string assembleSimpleIntrinsic(std::string_view intrinsicName, const FunctionCall& call);
270     std::string assembleUnaryOpIntrinsic(Operator op,
271                                          const FunctionCall& call,
272                                          Precedence parentPrecedence);
273     std::string assembleBinaryOpIntrinsic(Operator op,
274                                           const FunctionCall& call,
275                                           Precedence parentPrecedence);
276     std::string assembleVectorizedIntrinsic(std::string_view intrinsicName,
277                                             const FunctionCall& call);
278     std::string assembleOutAssignedIntrinsic(std::string_view intrinsicName,
279                                              std::string_view returnFieldName,
280                                              std::string_view outFieldName,
281                                              const FunctionCall& call);
282     std::string assemblePartialSampleCall(std::string_view intrinsicName,
283                                           const Expression& sampler,
284                                           const Expression& coords);
285     std::string assembleInversePolyfill(const FunctionCall& call);
286     std::string assembleComponentwiseMatrixBinary(const Type& leftType,
287                                                   const Type& rightType,
288                                                   const std::string& left,
289                                                   const std::string& right,
290                                                   Operator op);
291 
292     // Constructor expressions
293     std::string assembleAnyConstructor(const AnyConstructor& c);
294     std::string assembleConstructorCompound(const ConstructorCompound& c);
295     std::string assembleConstructorCompoundVector(const ConstructorCompound& c);
296     std::string assembleConstructorCompoundMatrix(const ConstructorCompound& c);
297     std::string assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
298     std::string assembleConstructorMatrixResize(const ConstructorMatrixResize& ctor);
299 
300     // Synthesized helper functions for comparison operators that are not supported by WGSL.
301     std::string assembleEqualityExpression(const Type& left,
302                                            const std::string& leftName,
303                                            const Type& right,
304                                            const std::string& rightName,
305                                            Operator op,
306                                            Precedence parentPrecedence);
307     std::string assembleEqualityExpression(const Expression& left,
308                                            const Expression& right,
309                                            Operator op,
310                                            Precedence parentPrecedence);
311 
312     // Writes a scratch variable into the program and returns its name (e.g. `_skTemp123`).
313     std::string writeScratchVar(const Type& type, const std::string& value = "");
314 
315     // Writes a scratch let-variable into the program, gives it the value of `expr`, and returns its
316     // name (e.g. `_skTemp123`).
317     std::string writeScratchLet(const std::string& expr, bool isCompileTimeConstant = false);
318     std::string writeScratchLet(const Expression& expr, Precedence parentPrecedence);
319 
320     // Converts `expr` into a string and returns a scratch let-variable associated with the
321     // expression. Compile-time constants and plain variable references will return the expression
322     // directly and omit the let-variable.
323     std::string writeNontrivialScratchLet(const Expression& expr, Precedence parentPrecedence);
324 
325     // Generic recursive ProgramElement visitor.
326     void writeProgramElement(const ProgramElement& e);
327     void writeGlobalVarDeclaration(const GlobalVarDeclaration& d);
328     void writeStructDefinition(const StructDefinition& s);
329     void writeModifiersDeclaration(const ModifiersDeclaration&);
330 
331     // Writes the WGSL struct fields for SkSL structs and interface blocks. Enforces WGSL address
332     // space layout constraints
333     // (https://www.w3.org/TR/WGSL/#address-space-layout-constraints) if a `layout` is
334     // provided. A struct that does not need to be host-shareable does not require a `layout`.
335     void writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout = nullptr);
336 
337     // We bundle uniforms, and all varying pipeline stage inputs and outputs, into separate structs.
338     bool needsStageInputStruct() const;
339     void writeStageInputStruct();
340     bool needsStageOutputStruct() const;
341     void writeStageOutputStruct();
342     void writeUniformsAndBuffers();
343     void prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock* interfaceBlock,
344                                                   std::string_view instanceName,
345                                                   MemoryLayout::Standard nativeLayout);
346     void writeEnables();
347     void writeUniformPolyfills();
348 
349     void writeTextureOrSampler(const Variable& var,
350                                int bindingLocation,
351                                std::string_view suffix,
352                                std::string_view wgslType);
353 
354     // Writes all top-level non-opaque global uniform declarations (i.e. not part of an interface
355     // block) into a single uniform block binding.
356     //
357     // In complete fragment/vertex/compute programs, uniforms will be declared only as interface
358     // blocks and global opaque types (like textures and samplers) which we expect to be declared
359     // with a unique binding and descriptor set index. However, test files that are declared as RTE
360     // programs may contain OpenGL-style global uniform declarations with no clear binding index to
361     // use for the containing synthesized block.
362     //
363     // Since we are handling these variables only to generate gold files from RTEs and never run
364     // them, we always declare them at the default bind group and binding index.
365     void writeNonBlockUniformsForTests();
366 
367     // For a given function declaration, writes out any implicitly required pipeline stage arguments
368     // based on the function's pre-determined dependencies. These are expected to be written out as
369     // the first parameters for a function that requires them. Returns true if any arguments were
370     // written.
371     std::string functionDependencyArgs(const FunctionDeclaration&);
372     bool writeFunctionDependencyParams(const FunctionDeclaration&);
373 
374     // Code in the header appears before the main body of code.
375     StringStream fHeader;
376 
377     // We assign unique names to anonymous interface blocks based on the type.
378     skia_private::THashMap<const Type*, std::string> fInterfaceBlockNameMap;
379 
380     // Stores the functions which use stage inputs/outputs as well as required WGSL extensions.
381     ProgramRequirements fRequirements;
382     skia_private::TArray<const Variable*> fPipelineInputs;
383     skia_private::TArray<const Variable*> fPipelineOutputs;
384 
385     // These fields track whether we have written the polyfill for `inverse()` for a given matrix
386     // type.
387     bool fWrittenInverse2 = false;
388     bool fWrittenInverse3 = false;
389     bool fWrittenInverse4 = false;
390     PrettyPrint fPrettyPrint;
391     IncludeSyntheticCode fGenSyntheticCode;
392 
393     // These fields control uniform polyfill support in cases where WGSL and std140 disagree.
394     // In std140 layout, matrices need to be represented as arrays of @size(16)-aligned vectors, and
395     // array elements are wrapped in a struct containing a single @size(16)-aligned element. Arrays
396     // of matrices combine both wrappers. These wrapper structs are unpacked into natively-typed
397     // globals at the shader entrypoint.
398     struct FieldPolyfillInfo {
399         const InterfaceBlock* fInterfaceBlock;
400         std::string fReplacementName;
401         bool fIsArray = false;
402         bool fIsMatrix = false;
403         bool fWasAccessed = false;
404     };
405     using FieldPolyfillMap = skia_private::THashMap<const Field*, FieldPolyfillInfo>;
406     FieldPolyfillMap fFieldPolyfillMap;
407 
408     // Output processing state.
409     int fIndentation = 0;
410     bool fAtLineStart = false;
411     bool fHasUnconditionalReturn = false;
412     bool fAtFunctionScope = false;
413     int fConditionalScopeDepth = 0;
414     int fLocalSizeX = 1;
415     int fLocalSizeY = 1;
416     int fLocalSizeZ = 1;
417 
418     int fScratchCount = 0;
419 };
420 
// Opaque declaration: this part of the file only needs the program-kind enum by name (e.g. as the
// parameter type of `pipeline_struct_prefix`); its enumerators are defined elsewhere.
enum class ProgramKind : std::int8_t;
422 
423 namespace {
424 
// Name suffixes used when a variable is emitted as a WGSL sampler or texture declaration
// (presumably passed as the `suffix` argument of `writeTextureOrSampler` -- confirm at call sites).
// Being inside an anonymous namespace already gives these internal linkage.
constexpr char kSamplerSuffix[] = "_Sampler";
constexpr char kTextureSuffix[] = "_Texture";
427 
// WGSL address spaces that a generated pointer type may name; `address_space_to_str` spells each
// one. Reference: https://www.w3.org/TR/WGSL/#memory-view-types
enum class PtrAddressSpace {
    kFunction = 0,
    kPrivate  = 1,
    kStorage  = 2,
};
434 
operator_name(Operator op)435 const char* operator_name(Operator op) {
436     switch (op.kind()) {
437         case Operator::Kind::LOGICALXOR:  return " != ";
438         default:                          return op.operatorName();
439     }
440 }
441 
is_reserved_word(std::string_view word)442 bool is_reserved_word(std::string_view word) {
443     static const THashSet<std::string_view> kReservedWords{
444             // Used by SkSL:
445             "FSIn",
446             "FSOut",
447             "VSIn",
448             "VSOut",
449             "CSIn",
450             "_globalUniforms",
451             "_GlobalUniforms",
452             "_return",
453             "_stageIn",
454             "_stageOut",
455             // Keywords: https://www.w3.org/TR/WGSL/#keyword-summary
456             "alias",
457             "break",
458             "case",
459             "const",
460             "const_assert",
461             "continue",
462             "continuing",
463             "default",
464             "diagnostic",
465             "discard",
466             "else",
467             "enable",
468             "false",
469             "fn",
470             "for",
471             "if",
472             "let",
473             "loop",
474             "override",
475             "requires",
476             "return",
477             "struct",
478             "switch",
479             "true",
480             "var",
481             "while",
482             // Pre-declared types: https://www.w3.org/TR/WGSL/#predeclared-types
483             "bool",
484             "f16",
485             "f32",
486             "i32",
487             "u32",
488             // ... and pre-declared type generators:
489             "array",
490             "atomic",
491             "mat2x2",
492             "mat2x3",
493             "mat2x4",
494             "mat3x2",
495             "mat3x3",
496             "mat3x4",
497             "mat4x2",
498             "mat4x3",
499             "mat4x4",
500             "ptr",
501             "texture_1d",
502             "texture_2d",
503             "texture_2d_array",
504             "texture_3d",
505             "texture_cube",
506             "texture_cube_array",
507             "texture_multisampled_2d",
508             "texture_storage_1d",
509             "texture_storage_2d",
510             "texture_storage_2d_array",
511             "texture_storage_3d",
512             "vec2",
513             "vec3",
514             "vec4",
515             // Pre-declared enumerants: https://www.w3.org/TR/WGSL/#predeclared-enumerants
516             "read",
517             "write",
518             "read_write",
519             "function",
520             "private",
521             "workgroup",
522             "uniform",
523             "storage",
524             "perspective",
525             "linear",
526             "flat",
527             "center",
528             "centroid",
529             "sample",
530             "vertex_index",
531             "instance_index",
532             "position",
533             "front_facing",
534             "frag_depth",
535             "local_invocation_id",
536             "local_invocation_index",
537             "global_invocation_id",
538             "workgroup_id",
539             "num_workgroups",
540             "sample_index",
541             "sample_mask",
542             "rgba8unorm",
543             "rgba8snorm",
544             "rgba8uint",
545             "rgba8sint",
546             "rgba16uint",
547             "rgba16sint",
548             "rgba16float",
549             "r32uint",
550             "r32sint",
551             "r32float",
552             "rg32uint",
553             "rg32sint",
554             "rg32float",
555             "rgba32uint",
556             "rgba32sint",
557             "rgba32float",
558             "bgra8unorm",
559             // Reserved words: https://www.w3.org/TR/WGSL/#reserved-words
560             "_",
561             "NULL",
562             "Self",
563             "abstract",
564             "active",
565             "alignas",
566             "alignof",
567             "as",
568             "asm",
569             "asm_fragment",
570             "async",
571             "attribute",
572             "auto",
573             "await",
574             "become",
575             "binding_array",
576             "cast",
577             "catch",
578             "class",
579             "co_await",
580             "co_return",
581             "co_yield",
582             "coherent",
583             "column_major",
584             "common",
585             "compile",
586             "compile_fragment",
587             "concept",
588             "const_cast",
589             "consteval",
590             "constexpr",
591             "constinit",
592             "crate",
593             "debugger",
594             "decltype",
595             "delete",
596             "demote",
597             "demote_to_helper",
598             "do",
599             "dynamic_cast",
600             "enum",
601             "explicit",
602             "export",
603             "extends",
604             "extern",
605             "external",
606             "fallthrough",
607             "filter",
608             "final",
609             "finally",
610             "friend",
611             "from",
612             "fxgroup",
613             "get",
614             "goto",
615             "groupshared",
616             "highp",
617             "impl",
618             "implements",
619             "import",
620             "inline",
621             "instanceof",
622             "interface",
623             "layout",
624             "lowp",
625             "macro",
626             "macro_rules",
627             "match",
628             "mediump",
629             "meta",
630             "mod",
631             "module",
632             "move",
633             "mut",
634             "mutable",
635             "namespace",
636             "new",
637             "nil",
638             "noexcept",
639             "noinline",
640             "nointerpolation",
641             "noperspective",
642             "null",
643             "nullptr",
644             "of",
645             "operator",
646             "package",
647             "packoffset",
648             "partition",
649             "pass",
650             "patch",
651             "pixelfragment",
652             "precise",
653             "precision",
654             "premerge",
655             "priv",
656             "protected",
657             "pub",
658             "public",
659             "readonly",
660             "ref",
661             "regardless",
662             "register",
663             "reinterpret_cast",
664             "require",
665             "resource",
666             "restrict",
667             "self",
668             "set",
669             "shared",
670             "sizeof",
671             "smooth",
672             "snorm",
673             "static",
674             "static_assert",
675             "static_cast",
676             "std",
677             "subroutine",
678             "super",
679             "target",
680             "template",
681             "this",
682             "thread_local",
683             "throw",
684             "trait",
685             "try",
686             "type",
687             "typedef",
688             "typeid",
689             "typename",
690             "typeof",
691             "union",
692             "unless",
693             "unorm",
694             "unsafe",
695             "unsized",
696             "use",
697             "using",
698             "varying",
699             "virtual",
700             "volatile",
701             "wgsl",
702             "where",
703             "with",
704             "writeonly",
705             "yield",
706     };
707 
708     return kReservedWords.contains(word);
709 }
710 
pipeline_struct_prefix(ProgramKind kind)711 std::string_view pipeline_struct_prefix(ProgramKind kind) {
712     if (ProgramConfig::IsVertex(kind)) {
713         return "VS";
714     }
715     if (ProgramConfig::IsFragment(kind)) {
716         return "FS";
717     }
718     if (ProgramConfig::IsCompute(kind)) {
719         return "CS";
720     }
721     // Compute programs don't have stage-in/stage-out pipeline structs.
722     return "";
723 }
724 
address_space_to_str(PtrAddressSpace addressSpace)725 std::string_view address_space_to_str(PtrAddressSpace addressSpace) {
726     switch (addressSpace) {
727         case PtrAddressSpace::kFunction:
728             return "function";
729         case PtrAddressSpace::kPrivate:
730             return "private";
731         case PtrAddressSpace::kStorage:
732             return "storage";
733     }
734     SkDEBUGFAIL("unsupported ptr address space");
735     return "unsupported";
736 }
737 
to_scalar_type(const Type & type)738 std::string_view to_scalar_type(const Type& type) {
739     SkASSERT(type.typeKind() == Type::TypeKind::kScalar);
740     switch (type.numberKind()) {
741         // Floating-point numbers in WebGPU currently always have 32-bit footprint and
742         // relaxed-precision is not supported without extensions. f32 is the only floating-point
743         // number type in WGSL (see the discussion on https://github.com/gpuweb/gpuweb/issues/658).
744         case Type::NumberKind::kFloat:
745             return "f32";
746         case Type::NumberKind::kSigned:
747             return "i32";
748         case Type::NumberKind::kUnsigned:
749             return "u32";
750         case Type::NumberKind::kBoolean:
751             return "bool";
752         case Type::NumberKind::kNonnumeric:
753             [[fallthrough]];
754         default:
755             break;
756     }
757     return type.name();
758 }
759 
760 // Convert a SkSL type to a WGSL type. Handles all plain types except structure types
761 // (see https://www.w3.org/TR/WGSL/#plain-types-section).
to_wgsl_type(const Context & context,const Type & raw,const Layout * layout=nullptr)762 std::string to_wgsl_type(const Context& context, const Type& raw, const Layout* layout = nullptr) {
763     const Type& type = raw.resolve().scalarTypeForLiteral();
764     switch (type.typeKind()) {
765         case Type::TypeKind::kScalar:
766             return std::string(to_scalar_type(type));
767 
768         case Type::TypeKind::kAtomic:
769             SkASSERT(type.matches(*context.fTypes.fAtomicUInt));
770             return "atomic<u32>";
771 
772         case Type::TypeKind::kVector: {
773             std::string_view ct = to_scalar_type(type.componentType());
774             return String::printf("vec%d<%.*s>", type.columns(), (int)ct.length(), ct.data());
775         }
776         case Type::TypeKind::kMatrix: {
777             std::string_view ct = to_scalar_type(type.componentType());
778             return String::printf("mat%dx%d<%.*s>",
779                                   type.columns(), type.rows(), (int)ct.length(), ct.data());
780         }
781         case Type::TypeKind::kArray: {
782             std::string result = "array<" + to_wgsl_type(context, type.componentType(), layout);
783             if (!type.isUnsizedArray()) {
784                 result += ", ";
785                 result += std::to_string(type.columns());
786             }
787             return result + '>';
788         }
789         case Type::TypeKind::kTexture: {
790             if (type.matches(*context.fTypes.fWriteOnlyTexture2D)) {
791                 std::string result = "texture_storage_2d<";
792                 // Write-only storage texture types require a pixel format, which is in the layout.
793                 SkASSERT(layout);
794                 LayoutFlags pixelFormat = layout->fFlags & LayoutFlag::kAllPixelFormats;
795                 switch (pixelFormat.value()) {
796                     case (int)LayoutFlag::kRGBA8:
797                         return result + "rgba8unorm, write>";
798 
799                     case (int)LayoutFlag::kRGBA32F:
800                         return result + "rgba32float, write>";
801 
802                     case (int)LayoutFlag::kR32F:
803                         return result + "r32float, write>";
804 
805                     default:
806                         // The front-end should have rejected this.
807                         return result + "write>";
808                 }
809             }
810             if (type.matches(*context.fTypes.fReadOnlyTexture2D)) {
811                 return "texture_2d<f32>";
812             }
813             break;
814         }
815         default:
816             break;
817     }
818     return std::string(type.name());
819 }
820 
to_ptr_type(const Context & context,const Type & type,const Layout * layout,PtrAddressSpace addressSpace=PtrAddressSpace::kFunction)821 std::string to_ptr_type(const Context& context,
822                         const Type& type,
823                         const Layout* layout,
824                         PtrAddressSpace addressSpace = PtrAddressSpace::kFunction) {
825     return "ptr<" + std::string(address_space_to_str(addressSpace)) + ", " +
826            to_wgsl_type(context, type, layout) + '>';
827 }
828 
wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin)829 std::string_view wgsl_builtin_name(WGSLCodeGenerator::Builtin builtin) {
830     using Builtin = WGSLCodeGenerator::Builtin;
831     switch (builtin) {
832         case Builtin::kVertexIndex:
833             return "@builtin(vertex_index)";
834         case Builtin::kInstanceIndex:
835             return "@builtin(instance_index)";
836         case Builtin::kPosition:
837             return "@builtin(position)";
838         case Builtin::kLastFragColor:
839             return "@color(0)";
840         case Builtin::kFrontFacing:
841             return "@builtin(front_facing)";
842         case Builtin::kSampleIndex:
843             return "@builtin(sample_index)";
844         case Builtin::kFragDepth:
845             return "@builtin(frag_depth)";
846         case Builtin::kSampleMask:
847         case Builtin::kSampleMaskIn:
848             return "@builtin(sample_mask)";
849         case Builtin::kLocalInvocationId:
850             return "@builtin(local_invocation_id)";
851         case Builtin::kLocalInvocationIndex:
852             return "@builtin(local_invocation_index)";
853         case Builtin::kGlobalInvocationId:
854             return "@builtin(global_invocation_id)";
855         case Builtin::kWorkgroupId:
856             return "@builtin(workgroup_id)";
857         case Builtin::kNumWorkgroups:
858             return "@builtin(num_workgroups)";
859         default:
860             break;
861     }
862 
863     SkDEBUGFAIL("unsupported builtin");
864     return "unsupported";
865 }
866 
wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin)867 std::string_view wgsl_builtin_type(WGSLCodeGenerator::Builtin builtin) {
868     using Builtin = WGSLCodeGenerator::Builtin;
869     switch (builtin) {
870         case Builtin::kVertexIndex:
871             return "u32";
872         case Builtin::kInstanceIndex:
873             return "u32";
874         case Builtin::kPosition:
875             return "vec4<f32>";
876         case Builtin::kLastFragColor:
877             return "vec4<f32>";
878         case Builtin::kFrontFacing:
879             return "bool";
880         case Builtin::kSampleIndex:
881             return "u32";
882         case Builtin::kFragDepth:
883             return "f32";
884         case Builtin::kSampleMask:
885             return "u32";
886         case Builtin::kSampleMaskIn:
887             return "u32";
888         case Builtin::kLocalInvocationId:
889             return "vec3<u32>";
890         case Builtin::kLocalInvocationIndex:
891             return "u32";
892         case Builtin::kGlobalInvocationId:
893             return "vec3<u32>";
894         case Builtin::kWorkgroupId:
895             return "vec3<u32>";
896         case Builtin::kNumWorkgroups:
897             return "vec3<u32>";
898         default:
899             break;
900     }
901 
902     SkDEBUGFAIL("unsupported builtin");
903     return "unsupported";
904 }
905 
906 // Some built-in variables have a type that differs from their SkSL counterpart (e.g. signed vs
907 // unsigned integer). We handle these cases with an explicit type conversion during a variable
908 // reference. Returns the WGSL type of the conversion target if conversion is needed, otherwise
909 // returns std::nullopt.
needs_builtin_type_conversion(const Variable & v)910 std::optional<std::string_view> needs_builtin_type_conversion(const Variable& v) {
911     switch (v.layout().fBuiltin) {
912         case SK_VERTEXID_BUILTIN:
913         case SK_INSTANCEID_BUILTIN:
914             return {"i32"};
915         default:
916             break;
917     }
918     return std::nullopt;
919 }
920 
921 // Map a SkSL builtin flag to a WGSL builtin kind. Returns std::nullopt if `builtin` is not
922 // not supported for WGSL.
923 //
924 // Also see //src/sksl/sksl_vert.sksl and //src/sksl/sksl_frag.sksl for supported built-ins.
builtin_from_sksl_name(int builtin)925 std::optional<WGSLCodeGenerator::Builtin> builtin_from_sksl_name(int builtin) {
926     using Builtin = WGSLCodeGenerator::Builtin;
927     switch (builtin) {
928         case SK_POSITION_BUILTIN:
929             [[fallthrough]];
930         case SK_FRAGCOORD_BUILTIN:
931             return Builtin::kPosition;
932         case SK_VERTEXID_BUILTIN:
933             return Builtin::kVertexIndex;
934         case SK_INSTANCEID_BUILTIN:
935             return Builtin::kInstanceIndex;
936         case SK_LASTFRAGCOLOR_BUILTIN:
937             return Builtin::kLastFragColor;
938         case SK_CLOCKWISE_BUILTIN:
939             // TODO(skia:13092): While `front_facing` is the corresponding built-in, it does not
940             // imply a particular winding order. We correctly compute the face orientation based
941             // on how Skia configured the render pipeline for all references to this built-in
942             // variable (see `SkSL::Program::Interface::fRTFlipUniform`).
943             return Builtin::kFrontFacing;
944         case SK_SAMPLEMASKIN_BUILTIN:
945             return Builtin::kSampleMaskIn;
946         case SK_SAMPLEMASK_BUILTIN:
947             return Builtin::kSampleMask;
948         case SK_NUMWORKGROUPS_BUILTIN:
949             return Builtin::kNumWorkgroups;
950         case SK_WORKGROUPID_BUILTIN:
951             return Builtin::kWorkgroupId;
952         case SK_LOCALINVOCATIONID_BUILTIN:
953             return Builtin::kLocalInvocationId;
954         case SK_GLOBALINVOCATIONID_BUILTIN:
955             return Builtin::kGlobalInvocationId;
956         case SK_LOCALINVOCATIONINDEX_BUILTIN:
957             return Builtin::kLocalInvocationIndex;
958         default:
959             break;
960     }
961     return std::nullopt;
962 }
963 
delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter)964 const char* delimiter_to_str(WGSLCodeGenerator::Delimiter delimiter) {
965     using Delim = WGSLCodeGenerator::Delimiter;
966     switch (delimiter) {
967         case Delim::kComma:
968             return ",";
969         case Delim::kSemicolon:
970             return ";";
971         case Delim::kNone:
972         default:
973             break;
974     }
975     return "";
976 }
977 
978 // FunctionDependencyResolver visits the IR tree rooted at a particular function definition and
979 // computes that function's dependencies on pipeline stage IO parameters. These are later used to
980 // synthesize arguments when writing out function definitions.
981 class FunctionDependencyResolver : public ProgramVisitor {
982 public:
983     using Deps = WGSLFunctionDependencies;
984     using DepsMap = WGSLCodeGenerator::ProgramRequirements::DepsMap;
985 
FunctionDependencyResolver(const Program * p,const FunctionDeclaration * f,DepsMap * programDependencyMap)986     FunctionDependencyResolver(const Program* p,
987                                const FunctionDeclaration* f,
988                                DepsMap* programDependencyMap)
989             : fProgram(p), fFunction(f), fDependencyMap(programDependencyMap) {}
990 
resolve()991     Deps resolve() {
992         fDeps = WGSLFunctionDependency::kNone;
993         this->visit(*fProgram);
994         return fDeps;
995     }
996 
997 private:
visitProgramElement(const ProgramElement & p)998     bool visitProgramElement(const ProgramElement& p) override {
999         // Only visit the program that matches the requested function.
1000         if (p.is<FunctionDefinition>() && &p.as<FunctionDefinition>().declaration() == fFunction) {
1001             return ProgramVisitor::visitProgramElement(p);
1002         }
1003         // Continue visiting other program elements.
1004         return false;
1005     }
1006 
visitExpression(const Expression & e)1007     bool visitExpression(const Expression& e) override {
1008         if (e.is<VariableReference>()) {
1009             const VariableReference& v = e.as<VariableReference>();
1010             if (v.variable()->storage() == Variable::Storage::kGlobal) {
1011                 ModifierFlags flags = v.variable()->modifierFlags();
1012                 if (flags & ModifierFlag::kIn) {
1013                     fDeps |= WGSLFunctionDependency::kPipelineInputs;
1014                 }
1015                 if (flags & ModifierFlag::kOut) {
1016                     fDeps |= WGSLFunctionDependency::kPipelineOutputs;
1017                 }
1018             }
1019         } else if (e.is<FunctionCall>()) {
1020             // The current function that we're processing (`fFunction`) inherits the dependencies of
1021             // functions that it makes calls to, because the pipeline stage IO parameters need to be
1022             // passed down as an argument.
1023             const FunctionCall& callee = e.as<FunctionCall>();
1024 
1025             // Don't process a function again if we have already resolved it.
1026             Deps* found = fDependencyMap->find(&callee.function());
1027             if (found) {
1028                 fDeps |= *found;
1029             } else {
1030                 // Store the dependencies that have been discovered for the current function so far.
1031                 // If `callee` directly or indirectly calls the current function, then this value
1032                 // will prevent an infinite recursion.
1033                 fDependencyMap->set(fFunction, fDeps);
1034 
1035                 // Separately traverse the called function's definition and determine its
1036                 // dependencies.
1037                 FunctionDependencyResolver resolver(fProgram, &callee.function(), fDependencyMap);
1038                 Deps calleeDeps = resolver.resolve();
1039 
1040                 // Store the callee's dependencies in the global map to avoid processing
1041                 // the function again for future calls.
1042                 fDependencyMap->set(&callee.function(), calleeDeps);
1043 
1044                 // Add to the current function's dependencies.
1045                 fDeps |= calleeDeps;
1046             }
1047         }
1048         return ProgramVisitor::visitExpression(e);
1049     }
1050 
1051     const Program* const fProgram;
1052     const FunctionDeclaration* const fFunction;
1053     DepsMap* const fDependencyMap;
1054     Deps fDeps = WGSLFunctionDependency::kNone;
1055 };
1056 
resolve_program_requirements(const Program * program)1057 WGSLCodeGenerator::ProgramRequirements resolve_program_requirements(const Program* program) {
1058     WGSLCodeGenerator::ProgramRequirements requirements;
1059 
1060     for (const ProgramElement* e : program->elements()) {
1061         switch (e->kind()) {
1062             case ProgramElement::Kind::kFunction: {
1063                 const FunctionDeclaration& decl = e->as<FunctionDefinition>().declaration();
1064 
1065                 FunctionDependencyResolver resolver(program, &decl, &requirements.fDependencies);
1066                 requirements.fDependencies.set(&decl, resolver.resolve());
1067                 break;
1068             }
1069             case ProgramElement::Kind::kGlobalVar: {
1070                 const GlobalVarDeclaration& decl = e->as<GlobalVarDeclaration>();
1071                 if (decl.varDeclaration().var()->modifierFlags().isPixelLocal()) {
1072                     requirements.fPixelLocalExtension = true;
1073                 }
1074                 break;
1075             }
1076             default:
1077                 break;
1078         }
1079     }
1080 
1081     return requirements;
1082 }
1083 
collect_pipeline_io_vars(const Program * program,TArray<const Variable * > * ioVars,ModifierFlag ioType)1084 void collect_pipeline_io_vars(const Program* program,
1085                               TArray<const Variable*>* ioVars,
1086                               ModifierFlag ioType) {
1087     for (const ProgramElement* e : program->elements()) {
1088         if (e->is<GlobalVarDeclaration>()) {
1089             const Variable* v = e->as<GlobalVarDeclaration>().varDeclaration().var();
1090             if (v->modifierFlags() & ioType) {
1091                 ioVars->push_back(v);
1092             }
1093         } else if (e->is<InterfaceBlock>()) {
1094             const Variable* v = e->as<InterfaceBlock>().var();
1095             if (v->modifierFlags() & ioType) {
1096                 ioVars->push_back(v);
1097             }
1098         }
1099     }
1100 }
1101 
is_in_global_uniforms(const Variable & var)1102 bool is_in_global_uniforms(const Variable& var) {
1103     SkASSERT(var.storage() == VariableStorage::kGlobal);
1104     return  var.modifierFlags().isUniform() &&
1105            !var.type().isOpaque() &&
1106            !var.interfaceBlock();
1107 }
1108 
1109 }  // namespace
1110 
1111 class WGSLCodeGenerator::LValue {
1112 public:
1113     virtual ~LValue() = default;
1114 
1115     // Returns a WGSL expression that loads from the lvalue with no side effects.
1116     // (e.g. `array[index].field`)
1117     virtual std::string load() = 0;
1118 
1119     // Returns a WGSL statement that stores into the lvalue with no side effects.
1120     // (e.g. `array[index].field = the_passed_in_value_string;`)
1121     virtual std::string store(const std::string& value) = 0;
1122 };
1123 
1124 class WGSLCodeGenerator::PointerLValue : public WGSLCodeGenerator::LValue {
1125 public:
1126     // `name` must be a WGSL expression with no side-effects, which we can safely take the address
1127     // of. (e.g. `array[index].field` would be valid, but `array[Func()]` or `vector.x` are not.)
PointerLValue(std::string name)1128     PointerLValue(std::string name) : fName(std::move(name)) {}
1129 
load()1130     std::string load() override {
1131         return fName;
1132     }
1133 
store(const std::string & value)1134     std::string store(const std::string& value) override {
1135         return fName + " = " + value + ";";
1136     }
1137 
1138 private:
1139     std::string fName;
1140 };
1141 
1142 class WGSLCodeGenerator::VectorComponentLValue : public WGSLCodeGenerator::LValue {
1143 public:
1144     // `name` must be a WGSL expression with no side-effects that points to a single component of a
1145     // WGSL vector.
VectorComponentLValue(std::string name)1146     VectorComponentLValue(std::string name) : fName(std::move(name)) {}
1147 
load()1148     std::string load() override {
1149         return fName;
1150     }
1151 
store(const std::string & value)1152     std::string store(const std::string& value) override {
1153         return fName + " = " + value + ";";
1154     }
1155 
1156 private:
1157     std::string fName;
1158 };
1159 
1160 class WGSLCodeGenerator::SwizzleLValue : public WGSLCodeGenerator::LValue {
1161 public:
1162     // `name` must be a WGSL expression with no side-effects that points to a WGSL vector.
SwizzleLValue(const Context & ctx,std::string name,const Type & t,const ComponentArray & c)1163     SwizzleLValue(const Context& ctx, std::string name, const Type& t, const ComponentArray& c)
1164             : fContext(ctx)
1165             , fName(std::move(name))
1166             , fType(t)
1167             , fComponents(c) {
1168         // If the component array doesn't cover the entire value, we need to create masks for
1169         // writing back into the lvalue. For example, if the type is vec4 and the component array
1170         // holds `zx`, a GLSL assignment would look like:
1171         //     name.zx = new_value;
1172         //
1173         // The equivalent WGSL assignment statement would look like:
1174         //     name = vec4<f32>(new_value, name.xw).yzxw;
1175         //
1176         // This replaces name.zy with new_value.xy, and leaves name.xw at their original values.
1177         // By convention, we always put the new value first and the original values second; it might
1178         // be possible to find better arrangements which simplify the assignment overall, but we
1179         // don't attempt this.
1180         int fullSlotCount = fType.slotCount();
1181         SkASSERT(fullSlotCount <= 4);
1182 
1183         // First, see which components are used.
1184         // The assignment swizzle must not reuse components.
1185         bool used[4] = {};
1186         for (int8_t component : fComponents) {
1187             SkASSERT(!used[component]);
1188             used[component] = true;
1189         }
1190 
1191         // Any untouched components will need to be fetched from the original value.
1192         for (int index = 0; index < fullSlotCount; ++index) {
1193             if (!used[index]) {
1194                 fUntouchedComponents.push_back(index);
1195             }
1196         }
1197 
1198         // The reintegration swizzle needs to move the components back into their proper slots.
1199         fReintegrationSwizzle.resize(fullSlotCount);
1200         int reintegrateIndex = 0;
1201 
1202         // This refills the untouched slots with the original values.
1203         auto refillUntouchedSlots = [&] {
1204             for (int index = 0; index < fullSlotCount; ++index) {
1205                 if (!used[index]) {
1206                     fReintegrationSwizzle[index] = reintegrateIndex++;
1207                 }
1208             }
1209         };
1210 
1211         // This places the new-value components into the proper slots.
1212         auto insertNewValuesIntoSlots = [&] {
1213             for (int index = 0; index < fComponents.size(); ++index) {
1214                 fReintegrationSwizzle[fComponents[index]] = reintegrateIndex++;
1215             }
1216         };
1217 
1218         // When reintegrating the untouched and new values, if the `x` slot is overwritten, we
1219         // reintegrate the new value first. Otherwise, we reintegrate the original value first.
1220         // This increases our odds of getting an identity swizzle for the reintegration.
1221         if (used[0]) {
1222             fReintegrateNewValueFirst = true;
1223             insertNewValuesIntoSlots();
1224             refillUntouchedSlots();
1225         } else {
1226             fReintegrateNewValueFirst = false;
1227             refillUntouchedSlots();
1228             insertNewValuesIntoSlots();
1229         }
1230     }
1231 
load()1232     std::string load() override {
1233         return fName + "." + Swizzle::MaskString(fComponents);
1234     }
1235 
store(const std::string & value)1236     std::string store(const std::string& value) override {
1237         // `variable = `
1238         std::string result = fName;
1239         result += " = ";
1240 
1241         if (fUntouchedComponents.empty()) {
1242             // `(new_value);`
1243             result += '(';
1244             result += value;
1245             result += ")";
1246         } else if (fReintegrateNewValueFirst) {
1247             // `vec4<f32>((new_value), `
1248             result += to_wgsl_type(fContext, fType);
1249             result += "((";
1250             result += value;
1251             result += "), ";
1252 
1253             // `variable.yz)`
1254             result += fName;
1255             result += '.';
1256             result += Swizzle::MaskString(fUntouchedComponents);
1257             result += ')';
1258         } else {
1259             // `vec4<f32>(variable.yz`
1260             result += to_wgsl_type(fContext, fType);
1261             result += '(';
1262             result += fName;
1263             result += '.';
1264             result += Swizzle::MaskString(fUntouchedComponents);
1265 
1266             // `, (new_value))`
1267             result += ", (";
1268             result += value;
1269             result += "))";
1270         }
1271 
1272         if (!Swizzle::IsIdentity(fReintegrationSwizzle)) {
1273             // `.wzyx`
1274             result += '.';
1275             result += Swizzle::MaskString(fReintegrationSwizzle);
1276         }
1277 
1278         return result + ';';
1279     }
1280 
1281 private:
1282     const Context& fContext;
1283     std::string fName;
1284     const Type& fType;
1285     ComponentArray fComponents;
1286     ComponentArray fUntouchedComponents;
1287     ComponentArray fReintegrationSwizzle;
1288     bool fReintegrateNewValueFirst = false;
1289 };
1290 
generateCode()1291 bool WGSLCodeGenerator::generateCode() {
1292     // The resources of a WGSL program are structured in the following way:
1293     // - Stage attribute inputs and outputs are bundled inside synthetic structs called
1294     //   VSIn/VSOut/FSIn/FSOut/CSIn.
1295     // - All uniform and storage type resources are declared in global scope.
1296     this->preprocessProgram();
1297 
1298     {
1299         AutoOutputStream outputToHeader(this, &fHeader, &fIndentation);
1300         this->writeEnables();
1301         this->writeStageInputStruct();
1302         this->writeStageOutputStruct();
1303         this->writeUniformsAndBuffers();
1304         this->writeNonBlockUniformsForTests();
1305     }
1306     StringStream body;
1307     {
1308         // Emit the program body.
1309         AutoOutputStream outputToBody(this, &body, &fIndentation);
1310         const FunctionDefinition* mainFunc = nullptr;
1311         for (const ProgramElement* e : fProgram.elements()) {
1312             this->writeProgramElement(*e);
1313 
1314             if (e->is<FunctionDefinition>()) {
1315                 const FunctionDefinition& func = e->as<FunctionDefinition>();
1316                 if (func.declaration().isMain()) {
1317                     mainFunc = &func;
1318                 }
1319             }
1320         }
1321 
1322         // At the bottom of the program body, emit the entrypoint function.
1323         // The entrypoint relies on state that has been collected while we emitted the rest of the
1324         // program, so it's important to do it last to make sure we don't miss anything.
1325         if (mainFunc) {
1326             this->writeEntryPoint(*mainFunc);
1327         }
1328     }
1329 
1330     write_stringstream(fHeader, *fOut);
1331     write_stringstream(body, *fOut);
1332 
1333     this->writeUniformPolyfills();
1334 
1335     return fContext.fErrors->errorCount() == 0;
1336 }
1337 
// Emits the WGSL polyfill declarations for the uniform fields recorded in fFieldPolyfillMap,
// writing through this->write()/writeLine() into the current output stream:
//  1. helper structs describing the padded uniform layout. These are
//     `_skArrayElement_<T>` for arrays, with each element wrapped in its own struct, and
//     `_skRow<R>` / `_skMatrix<C><R>` for matrices stored as arrays of 16-byte-sized vectors.
//     They are written for every recorded field, whether or not it was accessed.
//  2. one `var<private>` global (named info.fReplacementName) per field that was accessed;
//  3. `_skInitializePolyfilledUniforms()`, which copies each accessed field out of its interface
//     block into that global, converting the padded layout back into native WGSL types.
// Nothing is written if the map is empty. Steps 2 and 3 are skipped if no field was accessed.
void WGSLCodeGenerator::writeUniformPolyfills() {
    // If we didn't encounter any uniforms that need polyfilling, there is nothing to do.
    if (fFieldPolyfillMap.empty()) {
        return;
    }

    // We store the list of polyfilled fields as pointers in a hash-map, so the order can be
    // inconsistent across runs. For determinism, we sort the polyfilled objects by name here.
    TArray<const FieldPolyfillMap::Pair*> orderedFields;
    orderedFields.reserve_exact(fFieldPolyfillMap.count());

    fFieldPolyfillMap.foreach([&](const FieldPolyfillMap::Pair& pair) {
        orderedFields.push_back(&pair);
    });

    std::sort(orderedFields.begin(),
              orderedFields.end(),
              [](const FieldPolyfillMap::Pair* a, const FieldPolyfillMap::Pair* b) {
                  return a->second.fReplacementName < b->second.fReplacementName;
              });

    // Bookkeeping so each helper struct is declared at most once, even when several fields
    // share the same element type or matrix shape.
    THashSet<const Type*> writtenArrayElementPolyfill;
    bool writtenUniformMatrixPolyfill[5][5] = {};  // m[column][row] for each matrix type
    bool writtenUniformRowPolyfill[5] = {};        // for each matrix row-size
    bool anyFieldAccessed = false;
    for (const FieldPolyfillMap::Pair* pair : orderedFields) {
        const auto& [field, info] = *pair;
        const Type* fieldType = field->fType;
        const Layout* fieldLayout = &field->fLayout;

        if (info.fIsArray) {
            // From here on, `fieldType` refers to the array's element type.
            fieldType = &fieldType->componentType();
            if (!writtenArrayElementPolyfill.contains(fieldType)) {
                writtenArrayElementPolyfill.add(fieldType);
                this->write("struct _skArrayElement_");
                this->write(fieldType->abbreviatedName());
                this->writeLine(" {");

                if (info.fIsMatrix) {
                    // Create a struct representing the array containing std140-padded matrices.
                    this->write("  e : _skMatrix");
                    this->write(std::to_string(fieldType->columns()));
                    this->writeLine(std::to_string(fieldType->rows()));
                } else {
                    // Create a struct representing the array with extra padding between elements.
                    this->write("  @size(16) e : ");
                    this->writeLine(to_wgsl_type(fContext, *fieldType, fieldLayout));
                }
                this->writeLine("};");
            }
        }

        if (info.fIsMatrix) {
            // Create structs representing the matrix as an array of vectors, whether or not the
            // matrix is ever accessed by the SkSL. (The struct itself is mentioned in the list of
            // uniforms.)
            int c = fieldType->columns();
            int r = fieldType->rows();
            if (!writtenUniformRowPolyfill[r]) {
                writtenUniformRowPolyfill[r] = true;

                // `struct _skRow<r> { @size(16) r : vec<r><T> };`. Despite the name, each entry
                // holds `r` components, i.e. one matrix column (see the `.c[column].r`
                // reads in the initializer below).
                this->write("struct _skRow");
                this->write(std::to_string(r));
                this->writeLine(" {");
                this->write("  @size(16) r : vec");
                this->write(std::to_string(r));
                this->write("<");
                this->write(to_wgsl_type(fContext, fieldType->componentType(), fieldLayout));
                this->writeLine(">");
                this->writeLine("};");
            }

            if (!writtenUniformMatrixPolyfill[c][r]) {
                writtenUniformMatrixPolyfill[c][r] = true;

                // `struct _skMatrix<c><r> { c : array<_skRow<r>, <c>> };`
                this->write("struct _skMatrix");
                this->write(std::to_string(c));
                this->write(std::to_string(r));
                this->writeLine(" {");
                this->write("  c : array<_skRow");
                this->write(std::to_string(r));
                this->write(", ");
                this->write(std::to_string(c));
                this->writeLine(">");
                this->writeLine("};");
            }
        }

        // We create a polyfill variable only if the uniform was actually accessed.
        if (!info.fWasAccessed) {
            continue;
        }
        anyFieldAccessed = true;
        // The private global uses the field's native WGSL type (`field->fType`, not the
        // element type), and is itself arrayed when the interface block is arrayed.
        this->write("var<private> ");
        this->write(info.fReplacementName);
        this->write(": ");

        const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
        if (interfaceBlockType.isArray()) {
            this->write("array<");
            this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
            this->write(", ");
            this->write(std::to_string(interfaceBlockType.columns()));
            this->write(">");
        } else {
            this->write(to_wgsl_type(fContext, *field->fType, fieldLayout));
        }
        this->writeLine(";");
    }

    // If no fields were actually accessed, _skInitializePolyfilledUniforms will not be called and
    // we can avoid emitting an empty, dead function.
    if (!anyFieldAccessed) {
        return;
    }

    this->writeLine("fn _skInitializePolyfilledUniforms() {");
    ++fIndentation;

    for (const FieldPolyfillMap::Pair* pair : orderedFields) {
        // Only initialize a polyfill global if the uniform was actually accessed.
        const auto& [field, info] = *pair;
        if (!info.fWasAccessed) {
            continue;
        }

        // Synthesize the name of this uniform variable
        std::string_view instanceName = info.fInterfaceBlock->instanceName();
        const Type& interfaceBlockType = info.fInterfaceBlock->var()->type();
        if (instanceName.empty()) {
            // Anonymous interface block: use the name recorded in fInterfaceBlockNameMap.
            instanceName = fInterfaceBlockNameMap[&interfaceBlockType.componentType()];
        }

        // Initialize the global variable associated with this uniform.
        // If the interface block is arrayed, the associated global will be arrayed as well.
        int numIBElements = interfaceBlockType.isArray() ? interfaceBlockType.columns() : 1;
        for (int ibIdx = 0; ibIdx < numIBElements; ++ibIdx) {
            // `replacement[ibIdx] = ` (the index is omitted for a non-arrayed block).
            this->write(info.fReplacementName);
            if (interfaceBlockType.isArray()) {
                this->write("[");
                this->write(std::to_string(ibIdx));
                this->write("]");
            }
            this->write(" = ");

            const Type* fieldType = field->fType;
            const Layout* fieldLayout = &field->fLayout;

            // Arrayed fields are rebuilt with an array constructor, one argument per element;
            // `fieldType` then refers to the element type.
            int numArrayElements;
            if (info.fIsArray) {
                this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
                this->write("(");
                numArrayElements = fieldType->columns();
                fieldType = &fieldType->componentType();
            } else {
                numArrayElements = 1;
            }

            auto arraySeparator = String::Separator();
            for (int arrayIdx = 0; arrayIdx < numArrayElements; arrayIdx++) {
                this->write(arraySeparator());

                // Path to the stored value: `instance[ibIdx].field[arrayIdx].e`, where the
                // `[ibIdx]` and `[arrayIdx].e` parts appear only for arrayed blocks/fields.
                std::string fieldName{instanceName};
                if (interfaceBlockType.isArray()) {
                    fieldName += '[';
                    fieldName += std::to_string(ibIdx);
                    fieldName += ']';
                }
                fieldName += '.';
                fieldName += this->assembleName(field->fName);

                if (info.fIsArray) {
                    fieldName += '[';
                    fieldName += std::to_string(arrayIdx);
                    fieldName += "].e";
                }

                if (info.fIsMatrix) {
                    // Rebuild the matrix from its padded columns: `matCxR<T>(p.c[0].r, ...)`.
                    this->write(to_wgsl_type(fContext, *fieldType, fieldLayout));
                    this->write("(");
                    int numColumns = fieldType->columns();
                    auto matrixSeparator = String::Separator();
                    for (int column = 0; column < numColumns; column++) {
                        this->write(matrixSeparator());
                        this->write(fieldName);
                        this->write(".c[");
                        this->write(std::to_string(column));
                        this->write("].r");
                    }
                    this->write(")");
                } else {
                    this->write(fieldName);
                }
            }

            if (info.fIsArray) {
                this->write(")");
            }

            this->writeLine(";");
        }
    }

    --fIndentation;
    this->writeLine("}");
}
1544 
1545 
preprocessProgram()1546 void WGSLCodeGenerator::preprocessProgram() {
1547     fRequirements = resolve_program_requirements(&fProgram);
1548     collect_pipeline_io_vars(&fProgram, &fPipelineInputs, ModifierFlag::kIn);
1549     collect_pipeline_io_vars(&fProgram, &fPipelineOutputs, ModifierFlag::kOut);
1550 }
1551 
write(std::string_view s)1552 void WGSLCodeGenerator::write(std::string_view s) {
1553     if (s.empty()) {
1554         return;
1555     }
1556     if (fAtLineStart && fPrettyPrint == PrettyPrint::kYes) {
1557         for (int i = 0; i < fIndentation; i++) {
1558             fOut->writeText("  ");
1559         }
1560     }
1561     fOut->writeText(std::string(s).c_str());
1562     fAtLineStart = false;
1563 }
1564 
writeLine(std::string_view s)1565 void WGSLCodeGenerator::writeLine(std::string_view s) {
1566     this->write(s);
1567     fOut->writeText("\n");
1568     fAtLineStart = true;
1569 }
1570 
finishLine()1571 void WGSLCodeGenerator::finishLine() {
1572     if (!fAtLineStart) {
1573         this->writeLine();
1574     }
1575 }
1576 
assembleName(std::string_view name)1577 std::string WGSLCodeGenerator::assembleName(std::string_view name) {
1578     if (name.empty()) {
1579         // WGSL doesn't allow anonymous function parameters.
1580         return "_skAnonymous" + std::to_string(fScratchCount++);
1581     }
1582     // Add `R_` before reserved names to avoid any potential reserved-word conflict.
1583     return (skstd::starts_with(name, "_sk") ||
1584             skstd::starts_with(name, "R_") ||
1585             is_reserved_word(name))
1586                    ? std::string("R_") + std::string(name)
1587                    : std::string(name);
1588 }
1589 
writeVariableDecl(const Layout & layout,const Type & type,std::string_view name,Delimiter delimiter)1590 void WGSLCodeGenerator::writeVariableDecl(const Layout& layout,
1591                                           const Type& type,
1592                                           std::string_view name,
1593                                           Delimiter delimiter) {
1594     this->write(this->assembleName(name));
1595     this->write(": " + to_wgsl_type(fContext, type, &layout));
1596     this->writeLine(delimiter_to_str(delimiter));
1597 }
1598 
writePipelineIODeclaration(const Layout & layout,const Type & type,ModifierFlags modifiers,std::string_view name,Delimiter delimiter)1599 void WGSLCodeGenerator::writePipelineIODeclaration(const Layout& layout,
1600                                                    const Type& type,
1601                                                    ModifierFlags modifiers,
1602                                                    std::string_view name,
1603                                                    Delimiter delimiter) {
1604     // In WGSL, an entry-point IO parameter is "one of either a built-in value or assigned a
1605     // location". However, some SkSL declarations, specifically sk_FragColor, can contain both a
1606     // location and a builtin modifier. In addition, WGSL doesn't have a built-in equivalent for
1607     // sk_FragColor as it relies on the user-defined location for a render target.
1608     //
1609     // Instead of special-casing sk_FragColor, we just give higher precedence to a location modifier
1610     // if a declaration happens to both have a location and it's a built-in.
1611     //
1612     // Also see:
1613     // https://www.w3.org/TR/WGSL/#input-output-locations
1614     // https://www.w3.org/TR/WGSL/#attribute-location
1615     // https://www.w3.org/TR/WGSL/#builtin-inputs-outputs
1616     if (layout.fLocation >= 0) {
1617         this->writeUserDefinedIODecl(layout, type, modifiers, name, delimiter);
1618         return;
1619     }
1620     if (layout.fBuiltin >= 0) {
1621         if (layout.fBuiltin == SK_POINTSIZE_BUILTIN) {
1622             // WebGPU does not support the point-size builtin, but we silently replace it with a
1623             // global variable when it is used, instead of reporting an error.
1624             return;
1625         }
1626         auto builtin = builtin_from_sksl_name(layout.fBuiltin);
1627         if (builtin.has_value()) {
1628             // Builtin IO parameters should only have in/out modifiers, which are then implicit in
1629             // the generated WGSL, hence why writeBuiltinIODecl does not need them passed in.
1630             SkASSERT(!(modifiers & ~(ModifierFlag::kIn | ModifierFlag::kOut)));
1631             this->writeBuiltinIODecl(type, name, *builtin, delimiter);
1632             return;
1633         }
1634     }
1635     fContext.fErrors->error(Position(), "declaration '" + std::string(name) + "' is not supported");
1636 }
1637 
writeUserDefinedIODecl(const Layout & layout,const Type & type,ModifierFlags flags,std::string_view name,Delimiter delimiter)1638 void WGSLCodeGenerator::writeUserDefinedIODecl(const Layout& layout,
1639                                                const Type& type,
1640                                                ModifierFlags flags,
1641                                                std::string_view name,
1642                                                Delimiter delimiter) {
1643     this->write("@location(" + std::to_string(layout.fLocation) + ") ");
1644 
1645     // @blend_src is only allowed when doing dual-source blending, and only on color attachment 0.
1646     if (layout.fLocation == 0 && layout.fIndex >= 0 && fProgram.fInterface.fOutputSecondaryColor) {
1647         this->write("@blend_src(" + std::to_string(layout.fIndex) + ") ");
1648     }
1649 
1650     // "User-defined IO of scalar or vector integer type must always be specified as
1651     // @interpolate(flat)" (see https://www.w3.org/TR/WGSL/#interpolation)
1652     if (flags.isFlat() || type.isInteger() ||
1653         (type.isVector() && type.componentType().isInteger())) {
1654         // We can use 'either' to hint to WebGPU that we don't care about the provoking vertex and
1655         // avoid any expensive shader or data rewriting to ensure 'first'. Skia has a long-standing
1656         // policy to only use flat shading when it's constant for a primitive so the vertex doesn't
1657         // matter. See https://www.w3.org/TR/WGSL/#interpolation-sampling-either
1658         this->write("@interpolate(flat, either) ");
1659     } else if (flags & ModifierFlag::kNoPerspective) {
1660         this->write("@interpolate(linear) ");
1661     }
1662 
1663     this->writeVariableDecl(layout, type, name, delimiter);
1664 }
1665 
writeBuiltinIODecl(const Type & type,std::string_view name,Builtin builtin,Delimiter delimiter)1666 void WGSLCodeGenerator::writeBuiltinIODecl(const Type& type,
1667                                            std::string_view name,
1668                                            Builtin builtin,
1669                                            Delimiter delimiter) {
1670     this->write(wgsl_builtin_name(builtin));
1671     this->write(" ");
1672     this->write(this->assembleName(name));
1673     this->write(": ");
1674     this->write(wgsl_builtin_type(builtin));
1675     this->writeLine(delimiter_to_str(delimiter));
1676 }
1677 
writeFunction(const FunctionDefinition & f)1678 void WGSLCodeGenerator::writeFunction(const FunctionDefinition& f) {
1679     const FunctionDeclaration& decl = f.declaration();
1680     fHasUnconditionalReturn = false;
1681     fConditionalScopeDepth = 0;
1682 
1683     SkASSERT(!fAtFunctionScope);
1684     fAtFunctionScope = true;
1685 
1686     // WGSL parameters are immutable and are considered as taking no storage, but SkSL parameters
1687     // are real variables. To work around this, we make var-based copies of parameters. It's
1688     // wasteful to make a copy of every single parameter--even if the compiler can eventually
1689     // optimize them all away, that takes time and generates bloated code. So, we only make
1690     // parameter copies if the variable is actually written-to.
1691     STArray<32, bool> paramNeedsDedicatedStorage;
1692     paramNeedsDedicatedStorage.push_back_n(decl.parameters().size(), true);
1693 
1694     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1695         const Variable& param = *decl.parameters()[index];
1696         if (param.type().isOpaque() || param.name().empty()) {
1697             // Opaque-typed or anonymous parameters don't need dedicated storage.
1698             paramNeedsDedicatedStorage[index] = false;
1699             continue;
1700         }
1701 
1702         const ProgramUsage::VariableCounts counts = fProgram.fUsage->get(param);
1703         if ((param.modifierFlags() & ModifierFlag::kOut) || counts.fWrite == 0) {
1704             // Variables which are never written-to don't need dedicated storage.
1705             // Out-parameters are passed as pointers; the pointer itself is never modified, so
1706             // it doesn't need dedicated storage.
1707             paramNeedsDedicatedStorage[index] = false;
1708         }
1709     }
1710 
1711     this->writeFunctionDeclaration(decl, paramNeedsDedicatedStorage);
1712     this->writeLine(" {");
1713     ++fIndentation;
1714 
1715     // The parameters were given generic names like `_skParam1`, because WGSL parameters don't have
1716     // storage and are immutable. If mutability is required, we create variables here; otherwise, we
1717     // create properly-named `let` aliases.
1718     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1719         if (paramNeedsDedicatedStorage[index]) {
1720             const Variable& param = *decl.parameters()[index];
1721             this->write("var ");
1722             this->write(this->assembleName(param.mangledName()));
1723             this->write(" = _skParam");
1724             this->write(std::to_string(index));
1725             this->writeLine(";");
1726         }
1727     }
1728 
1729     this->writeBlock(f.body()->as<Block>());
1730 
1731     // If fConditionalScopeDepth isn't zero, we have an unbalanced +1 or -1 when updating the depth.
1732     SkASSERT(fConditionalScopeDepth == 0);
1733     if (!fHasUnconditionalReturn && !f.declaration().returnType().isVoid()) {
1734         this->write("return ");
1735         this->write(to_wgsl_type(fContext, f.declaration().returnType()));
1736         this->writeLine("();");
1737     }
1738 
1739     --fIndentation;
1740     this->writeLine("}");
1741 
1742     SkASSERT(fAtFunctionScope);
1743     fAtFunctionScope = false;
1744 }
1745 
writeFunctionDeclaration(const FunctionDeclaration & decl,SkSpan<const bool> paramNeedsDedicatedStorage)1746 void WGSLCodeGenerator::writeFunctionDeclaration(const FunctionDeclaration& decl,
1747                                                  SkSpan<const bool> paramNeedsDedicatedStorage) {
1748     this->write("fn ");
1749     if (decl.isMain()) {
1750         this->write("_skslMain(");
1751     } else {
1752         this->write(this->assembleName(decl.mangledName()));
1753         this->write("(");
1754     }
1755     auto separator = SkSL::String::Separator();
1756     if (this->writeFunctionDependencyParams(decl)) {
1757         separator();  // update the separator as parameters have been written
1758     }
1759     for (size_t index = 0; index < decl.parameters().size(); ++index) {
1760         this->write(separator());
1761 
1762         const Variable& param = *decl.parameters()[index];
1763         if (param.type().isOpaque()) {
1764             SkASSERT(!paramNeedsDedicatedStorage[index]);
1765             if (param.type().isSampler()) {
1766                 // Create parameters for both the texture and associated sampler.
1767                 this->write(param.name());
1768                 this->write(kTextureSuffix);
1769                 this->write(": texture_2d<f32>, ");
1770                 this->write(param.name());
1771                 this->write(kSamplerSuffix);
1772                 this->write(": sampler");
1773             } else {
1774                 // Create a parameter for the opaque object.
1775                 this->write(param.name());
1776                 this->write(": ");
1777                 this->write(to_wgsl_type(fContext, param.type(), &param.layout()));
1778             }
1779         } else {
1780             if (paramNeedsDedicatedStorage[index] || param.name().empty()) {
1781                 // Create an unnamed parameter. If the parameter needs dedicated storage, it will
1782                 // later be assigned a `var` in the function body. (If it's anonymous, a var isn't
1783                 // needed.)
1784                 this->write("_skParam");
1785                 this->write(std::to_string(index));
1786             } else {
1787                 // Use the name directly from the SkSL program.
1788                 this->write(this->assembleName(param.name()));
1789             }
1790             this->write(": ");
1791             if (param.type().isUnsizedArray()) {
1792                 // Creates a storage address space pointer for unsized array parameters.
1793                 // The buffer the array resides in must be marked `readonly` to have the array
1794                 // be used in function parameters, since access modes in wgsl must exactly match.
1795                 this->write("ptr<storage, ");
1796                 this->write(to_wgsl_type(fContext, param.type(), &param.layout()));
1797                 this->write(", read>");
1798             } else if (param.modifierFlags() & ModifierFlag::kOut) {
1799                 // Declare an "out" function parameter as a pointer.
1800                 this->write(to_ptr_type(fContext, param.type(), &param.layout()));
1801             } else {
1802                 this->write(to_wgsl_type(fContext, param.type(), &param.layout()));
1803             }
1804         }
1805     }
1806     this->write(")");
1807     if (!decl.returnType().isVoid()) {
1808         this->write(" -> ");
1809         this->write(to_wgsl_type(fContext, decl.returnType()));
1810     }
1811 }
1812 
writeEntryPoint(const FunctionDefinition & main)1813 void WGSLCodeGenerator::writeEntryPoint(const FunctionDefinition& main) {
1814     SkASSERT(main.declaration().isMain());
1815     const ProgramKind programKind = fProgram.fConfig->fKind;
1816 
1817     if (fGenSyntheticCode == IncludeSyntheticCode::kYes &&
1818         ProgramConfig::IsRuntimeShader(programKind)) {
1819         // Synthesize a basic entrypoint which just calls straight through to main.
1820         // This is only used by skslc and just needs to pass the WGSL validator; Skia won't ever
1821         // emit functions like this.
1822         this->writeLine("@fragment fn main(@location(0) _coords: vec2<f32>) -> "
1823                                      "@location(0) vec4<f32> {");
1824         ++fIndentation;
1825         this->writeLine("return _skslMain(_coords);");
1826         --fIndentation;
1827         this->writeLine("}");
1828         return;
1829     }
1830 
1831     // The input and output parameters for a vertex/fragment stage entry point function have the
1832     // FSIn/FSOut/VSIn/VSOut/CSIn struct types that have been synthesized in generateCode(). An
1833     // entrypoint always has a predictable signature and acts as a trampoline to the user-defined
1834     // main function.
1835     if (ProgramConfig::IsVertex(programKind)) {
1836         this->write("@vertex");
1837     } else if (ProgramConfig::IsFragment(programKind)) {
1838         this->write("@fragment");
1839     } else if (ProgramConfig::IsCompute(programKind)) {
1840         this->write("@compute @workgroup_size(");
1841         this->write(std::to_string(fLocalSizeX));
1842         this->write(", ");
1843         this->write(std::to_string(fLocalSizeY));
1844         this->write(", ");
1845         this->write(std::to_string(fLocalSizeZ));
1846         this->write(")");
1847     } else {
1848         fContext.fErrors->error(Position(), "program kind not supported");
1849         return;
1850     }
1851 
1852     this->write(" fn main(");
1853     // The stage input struct is a parameter passed to main().
1854     if (this->needsStageInputStruct()) {
1855         this->write("_stageIn: ");
1856         this->write(pipeline_struct_prefix(programKind));
1857         this->write("In");
1858     }
1859     // The stage output struct is returned from main().
1860     if (this->needsStageOutputStruct()) {
1861         this->write(") -> ");
1862         this->write(pipeline_struct_prefix(programKind));
1863         this->writeLine("Out {");
1864     } else {
1865         this->writeLine(") {");
1866     }
1867     // Initialize polyfilled matrix uniforms if any were used.
1868     fIndentation++;
1869     for (const auto& [field, info] : fFieldPolyfillMap) {
1870         if (info.fWasAccessed) {
1871             this->writeLine("_skInitializePolyfilledUniforms();");
1872             break;
1873         }
1874     }
1875     // Declare the stage output struct.
1876     if (this->needsStageOutputStruct()) {
1877         this->write("var _stageOut: ");
1878         this->write(pipeline_struct_prefix(programKind));
1879         this->writeLine("Out;");
1880     }
1881 
1882     // We are compiling a Runtime Effect as a fragment shader, for testing purposes. We assign the
1883     // result from _skslMain into sk_FragColor if the user-defined main returns a color. This
1884     // doesn't actually matter, but it is more indicative of what a real program would do.
1885     // `addImplicitFragColorWrite` from Transform::FindAndDeclareBuiltinVariables has already
1886     // injected sk_FragColor into our stage outputs even if it wasn't explicitly referenced.
1887     if (fGenSyntheticCode == IncludeSyntheticCode::kYes && ProgramConfig::IsFragment(programKind)) {
1888         if (main.declaration().returnType().matches(*fContext.fTypes.fHalf4)) {
1889             this->write("_stageOut.sk_FragColor = ");
1890         }
1891     }
1892 
1893     // Generate a function call to the user-defined main.
1894     this->write("_skslMain(");
1895     auto separator = SkSL::String::Separator();
1896     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&main.declaration());
1897     if (deps) {
1898         if (*deps & WGSLFunctionDependency::kPipelineInputs) {
1899             this->write(separator());
1900             this->write("_stageIn");
1901         }
1902         if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
1903             this->write(separator());
1904             this->write("&_stageOut");
1905         }
1906     }
1907 
1908     if (fGenSyntheticCode == IncludeSyntheticCode::kYes) {
1909         if (const Variable* v = main.declaration().getMainCoordsParameter()) {
1910             // We are compiling a Runtime Effect as a fragment shader, for testing purposes.
1911             // We need to synthesize a coordinates parameter, but the coordinates don't matter.
1912             SkASSERT(ProgramConfig::IsFragment(programKind));
1913             const Type& type = v->type();
1914             if (!type.matches(*fContext.fTypes.fFloat2)) {
1915                 fContext.fErrors->error(
1916                         main.fPosition,
1917                         "main function has unsupported parameter: " + type.description());
1918                 return;
1919             }
1920             this->write(separator());
1921             this->write("/*fragcoord*/ vec2<f32>()");
1922         }
1923     }
1924 
1925     this->writeLine(");");
1926 
1927     if (this->needsStageOutputStruct()) {
1928         // Return the stage output struct.
1929         this->writeLine("return _stageOut;");
1930     }
1931 
1932     fIndentation--;
1933     this->writeLine("}");
1934 }
1935 
writeStatement(const Statement & s)1936 void WGSLCodeGenerator::writeStatement(const Statement& s) {
1937     switch (s.kind()) {
1938         case Statement::Kind::kBlock:
1939             this->writeBlock(s.as<Block>());
1940             break;
1941         case Statement::Kind::kBreak:
1942             this->writeLine("break;");
1943             break;
1944         case Statement::Kind::kContinue:
1945             this->writeLine("continue;");
1946             break;
1947         case Statement::Kind::kDiscard:
1948             this->writeLine("discard;");
1949             break;
1950         case Statement::Kind::kDo:
1951             this->writeDoStatement(s.as<DoStatement>());
1952             break;
1953         case Statement::Kind::kExpression:
1954             this->writeExpressionStatement(*s.as<ExpressionStatement>().expression());
1955             break;
1956         case Statement::Kind::kFor:
1957             this->writeForStatement(s.as<ForStatement>());
1958             break;
1959         case Statement::Kind::kIf:
1960             this->writeIfStatement(s.as<IfStatement>());
1961             break;
1962         case Statement::Kind::kNop:
1963             this->writeLine(";");
1964             break;
1965         case Statement::Kind::kReturn:
1966             this->writeReturnStatement(s.as<ReturnStatement>());
1967             break;
1968         case Statement::Kind::kSwitch:
1969             this->writeSwitchStatement(s.as<SwitchStatement>());
1970             break;
1971         case Statement::Kind::kSwitchCase:
1972             SkDEBUGFAIL("switch-case statements should only be present inside a switch");
1973             break;
1974         case Statement::Kind::kVarDeclaration:
1975             this->writeVarDeclaration(s.as<VarDeclaration>());
1976             break;
1977     }
1978 }
1979 
writeStatements(const StatementArray & statements)1980 void WGSLCodeGenerator::writeStatements(const StatementArray& statements) {
1981     for (const auto& s : statements) {
1982         if (!s->isEmpty()) {
1983             this->writeStatement(*s);
1984             this->finishLine();
1985         }
1986     }
1987 }
1988 
writeBlock(const Block & b)1989 void WGSLCodeGenerator::writeBlock(const Block& b) {
1990     // Write scope markers if this block is a scope, or if the block is empty (since we need to emit
1991     // something here to make the code valid).
1992     bool isScope = b.isScope() || b.isEmpty();
1993     if (isScope) {
1994         this->writeLine("{");
1995         fIndentation++;
1996     }
1997     this->writeStatements(b.children());
1998     if (isScope) {
1999         fIndentation--;
2000         this->writeLine("}");
2001     }
2002 }
2003 
writeExpressionStatement(const Expression & expr)2004 void WGSLCodeGenerator::writeExpressionStatement(const Expression& expr) {
2005     // Any expression-related side effects must be emitted as separate statements when
2006     // `assembleExpression` is called.
2007     // The final result of the expression will be a variable, let-reference, or an expression with
2008     // no side effects (`foo + bar`). Discarding this result is safe, as the program never uses it.
2009     (void)this->assembleExpression(expr, Precedence::kStatement);
2010 }
2011 
writeDoStatement(const DoStatement & s)2012 void WGSLCodeGenerator::writeDoStatement(const DoStatement& s) {
2013     // Generate a loop structure like this:
2014     //   loop {
2015     //       body-statement;
2016     //       continuing {
2017     //           break if inverted-test-expression;
2018     //       }
2019     //   }
2020 
2021     ++fConditionalScopeDepth;
2022 
2023     std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2024             fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2025 
2026     this->writeLine("loop {");
2027     fIndentation++;
2028     this->writeStatement(*s.statement());
2029     this->finishLine();
2030 
2031     this->writeLine("continuing {");
2032     fIndentation++;
2033     std::string breakIfExpr = this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2034     this->write("break if ");
2035     this->write(breakIfExpr);
2036     this->writeLine(";");
2037     fIndentation--;
2038     this->writeLine("}");
2039     fIndentation--;
2040     this->writeLine("}");
2041 
2042     --fConditionalScopeDepth;
2043 }
2044 
writeForStatement(const ForStatement & s)2045 void WGSLCodeGenerator::writeForStatement(const ForStatement& s) {
2046     // Generate a loop structure wrapped in an extra scope:
2047     //   {
2048     //     initializer-statement;
2049     //     loop;
2050     //   }
2051     // The outer scope is necessary to prevent the initializer-variable from leaking out into the
2052     // rest of the code. In practice, the generated code actually tends to be scoped even more
2053     // deeply, as the body-statement almost always contributes an extra block.
2054 
2055     ++fConditionalScopeDepth;
2056 
2057     if (s.initializer()) {
2058         this->writeLine("{");
2059         fIndentation++;
2060         this->writeStatement(*s.initializer());
2061         this->writeLine();
2062     }
2063 
2064     this->writeLine("loop {");
2065     fIndentation++;
2066 
2067     if (s.unrollInfo()) {
2068         if (s.unrollInfo()->fCount <= 0) {
2069             // Loops which are known to never execute don't need to be emitted at all.
2070             // (However, the front end should have already replaced this loop with a Nop.)
2071         } else {
2072             // Loops which are known to execute at least once can use this form:
2073             //
2074             //     loop {
2075             //         body-statement;
2076             //         continuing {
2077             //             next-expression;
2078             //             break if inverted-test-expression;
2079             //         }
2080             //     }
2081 
2082             this->writeStatement(*s.statement());
2083             this->finishLine();
2084             this->writeLine("continuing {");
2085             ++fIndentation;
2086 
2087             if (s.next()) {
2088                 this->writeExpressionStatement(*s.next());
2089                 this->finishLine();
2090             }
2091 
2092             if (s.test()) {
2093                 std::unique_ptr<Expression> invertedTestExpr = PrefixExpression::Make(
2094                         fContext, s.test()->fPosition, OperatorKind::LOGICALNOT, s.test()->clone());
2095 
2096                 std::string breakIfExpr =
2097                         this->assembleExpression(*invertedTestExpr, Precedence::kExpression);
2098                 this->write("break if ");
2099                 this->write(breakIfExpr);
2100                 this->writeLine(";");
2101             }
2102 
2103             --fIndentation;
2104             this->writeLine("}");
2105         }
2106     } else {
2107         // Loops without a known execution count are emitted in this form:
2108         //
2109         //     loop {
2110         //         if test-expression {
2111         //             body-statement;
2112         //         } else {
2113         //             break;
2114         //         }
2115         //         continuing {
2116         //             next-expression;
2117         //         }
2118         //     }
2119 
2120         if (s.test()) {
2121             std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2122             this->write("if ");
2123             this->write(testExpr);
2124             this->writeLine(" {");
2125 
2126             fIndentation++;
2127             this->writeStatement(*s.statement());
2128             this->finishLine();
2129             fIndentation--;
2130 
2131             this->writeLine("} else {");
2132 
2133             fIndentation++;
2134             this->writeLine("break;");
2135             fIndentation--;
2136 
2137             this->writeLine("}");
2138         }
2139         else {
2140             this->writeStatement(*s.statement());
2141             this->finishLine();
2142         }
2143 
2144         if (s.next()) {
2145             this->writeLine("continuing {");
2146             fIndentation++;
2147             this->writeExpressionStatement(*s.next());
2148             this->finishLine();
2149             fIndentation--;
2150             this->writeLine("}");
2151         }
2152     }
2153 
2154     // This matches an open-brace at the top of the loop.
2155     fIndentation--;
2156     this->writeLine("}");
2157 
2158     if (s.initializer()) {
2159         // This matches an open-brace before the initializer-statement.
2160         fIndentation--;
2161         this->writeLine("}");
2162     }
2163 
2164     --fConditionalScopeDepth;
2165 }
2166 
writeIfStatement(const IfStatement & s)2167 void WGSLCodeGenerator::writeIfStatement(const IfStatement& s) {
2168     ++fConditionalScopeDepth;
2169 
2170     std::string testExpr = this->assembleExpression(*s.test(), Precedence::kExpression);
2171     this->write("if ");
2172     this->write(testExpr);
2173     this->writeLine(" {");
2174     fIndentation++;
2175     this->writeStatement(*s.ifTrue());
2176     this->finishLine();
2177     fIndentation--;
2178     if (s.ifFalse()) {
2179         this->writeLine("} else {");
2180         fIndentation++;
2181         this->writeStatement(*s.ifFalse());
2182         this->finishLine();
2183         fIndentation--;
2184     }
2185     this->writeLine("}");
2186 
2187     --fConditionalScopeDepth;
2188 }
2189 
writeReturnStatement(const ReturnStatement & s)2190 void WGSLCodeGenerator::writeReturnStatement(const ReturnStatement& s) {
2191     fHasUnconditionalReturn |= (fConditionalScopeDepth == 0);
2192 
2193     std::string expr = s.expression()
2194                                ? this->assembleExpression(*s.expression(), Precedence::kExpression)
2195                                : std::string();
2196     this->write("return ");
2197     this->write(expr);
2198     this->write(";");
2199 }
2200 
writeSwitchCaseList(SkSpan<const SwitchCase * const> cases)2201 void WGSLCodeGenerator::writeSwitchCaseList(SkSpan<const SwitchCase* const> cases) {
2202     auto separator = SkSL::String::Separator();
2203     for (const SwitchCase* const sc : cases) {
2204         this->write(separator());
2205         if (sc->isDefault()) {
2206             this->write("default");
2207         } else {
2208             this->write(std::to_string(sc->value()));
2209         }
2210     }
2211 }
2212 
writeSwitchCases(SkSpan<const SwitchCase * const> cases)2213 void WGSLCodeGenerator::writeSwitchCases(SkSpan<const SwitchCase* const> cases) {
2214     if (!cases.empty()) {
2215         // Only the last switch-case should have a non-empty statement.
2216         SkASSERT(std::all_of(cases.begin(), std::prev(cases.end()), [](const SwitchCase* sc) {
2217             return sc->statement()->isEmpty();
2218         }));
2219 
2220         // Emit the cases in a comma-separated list.
2221         this->write("case ");
2222         this->writeSwitchCaseList(cases);
2223         this->writeLine(" {");
2224         ++fIndentation;
2225 
2226         // Emit the switch-case body.
2227         this->writeStatement(*cases.back()->statement());
2228         this->finishLine();
2229 
2230         --fIndentation;
2231         this->writeLine("}");
2232     }
2233 }
2234 
// Emits a group of SkSL switch-cases that fall through into one another as a single WGSL case
// clause. WGSL has no fallthrough, so within that clause each case body except the last is
// guarded by an `if` that tests either the fallthrough flag or `switchValue` against that case's
// value. The final case body is written unguarded.
//   cases       - the consecutive SkSL cases forming the fallthrough group, in source order.
//   switchValue - the (already-assembled) switch selector expression to compare against.
void WGSLCodeGenerator::writeEmulatedSwitchFallthroughCases(SkSpan<const SwitchCase* const> cases,
                                                            std::string_view switchValue) {
    // There's no need for fallthrough handling unless we actually have multiple case blocks.
    if (cases.size() < 2) {
        this->writeSwitchCases(cases);
        return;
    }

    // Match against the entire case group.
    this->write("case ");
    this->writeSwitchCaseList(cases);
    this->writeLine(" {");
    ++fIndentation;

    // A bool scratch variable, initially false, that records whether an earlier case in the
    // group has already matched and we are falling through into later cases.
    std::string fallthroughVar = this->writeScratchVar(*fContext.fTypes.fBool, "false");
    const size_t secondToLastCaseIndex = cases.size() - 2;
    const size_t lastCaseIndex = cases.size() - 1;

    for (size_t index = 0; index < cases.size(); ++index) {
        const SwitchCase& sc = *cases[index];
        if (index < lastCaseIndex) {
            // The default case must come last in SkSL, and this case isn't the last one, so it
            // can't possibly be the default.
            SkASSERT(!sc.isDefault());

            // The first case has nothing to fall through from, so it only tests its own value.
            this->write("if ");
            if (index > 0) {
                this->write(fallthroughVar);
                this->write(" || ");
            }
            this->write(switchValue);
            this->write(" == ");
            this->write(std::to_string(sc.value()));
            this->writeLine(" {");
            fIndentation++;

            // We write the entire case-block statement here, and then set the fallthrough
            // variable to true. If the case-block had a break statement in it, control leaves the
            // enclosing WGSL switch entirely, meaning the fallthrough assignment never occurs,
            // nor does any code after it inside the switch. We've forbidden `continue` statements
            // inside switch case-blocks entirely, so we don't need to consider their effect on
            // control flow; see the Finalizer in FunctionDefinition::Convert.
            this->writeStatement(*sc.statement());
            this->finishLine();

            if (index < secondToLastCaseIndex) {
                // Set a variable to indicate falling through to the next block. The very last
                // case-block is reached by process of elimination and doesn't need this
                // variable, so we don't actually need to set it if we are on the second-to-last
                // case block.
                this->write(fallthroughVar);
                this->write(" = true;  ");
            }
            // The `// fallthrough` marker is emitted for every guarded case, including the
            // second-to-last one where no assignment precedes it.
            this->writeLine("// fallthrough");

            fIndentation--;
            this->writeLine("}");
        } else {
            // This is the final case. Since it's always last, we can just dump in the code.
            // (If we didn't match any of the other values, we must have matched this one by
            // process of elimination. If we did match one of the other values, we either hit a
            // `break` statement earlier--and won't get this far--or we're falling through.)
            this->writeStatement(*sc.statement());
            this->finishLine();
        }
    }

    --fIndentation;
    this->writeLine("}");
}
2305 
writeSwitchStatement(const SwitchStatement & s)2306 void WGSLCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
2307     // WGSL supports the `switch` statement in a limited capacity. A default case must always be
2308     // specified. Each switch-case must be scoped inside braces. Fallthrough is not supported; a
2309     // trailing break is implied at the end of each switch-case block. (Explicit breaks are also
2310     // allowed.)  One minor improvement over a traditional switch is that switch-cases take a list
2311     // of values to match, instead of a single value:
2312     //   case 1, 2       { foo(); }
2313     //   case 3, default { bar(); }
2314     //
2315     // We will use the native WGSL switch statement for any switch-cases in the SkSL which can be
2316     // made to conform to these limitations. The remaining cases which cannot conform will be
2317     // emulated with if-else blocks (similar to our GLSL ES2 switch-statement emulation path). This
2318     // should give us good performance in the common case, since most switches naturally conform.
2319 
2320     // First, let's emit the switch itself.
2321     std::string valueExpr = this->writeNontrivialScratchLet(*s.value(), Precedence::kExpression);
2322     this->write("switch ");
2323     this->write(valueExpr);
2324     this->writeLine(" {");
2325     ++fIndentation;
2326 
2327     // Now let's go through the switch-cases, and emit the ones that don't fall through.
2328     TArray<const SwitchCase*> nativeCases;
2329     TArray<const SwitchCase*> fallthroughCases;
2330     bool previousCaseFellThrough = false;
2331     bool foundNativeDefault = false;
2332     [[maybe_unused]] bool foundFallthroughDefault = false;
2333 
2334     const int lastSwitchCaseIdx = s.cases().size() - 1;
2335     for (int index = 0; index <= lastSwitchCaseIdx; ++index) {
2336         const SwitchCase& sc = s.cases()[index]->as<SwitchCase>();
2337 
2338         if (sc.statement()->isEmpty()) {
2339             // This is a `case X:` that immediately falls through to the next case.
2340             // If we aren't already falling through, we can handle this via a comma-separated list.
2341             if (previousCaseFellThrough) {
2342                 fallthroughCases.push_back(&sc);
2343                 foundFallthroughDefault |= sc.isDefault();
2344             } else {
2345                 nativeCases.push_back(&sc);
2346                 foundNativeDefault |= sc.isDefault();
2347             }
2348             continue;
2349         }
2350 
2351         if (index == lastSwitchCaseIdx || Analysis::SwitchCaseContainsUnconditionalExit(sc)) {
2352             // This is a `case X:` that never falls through.
2353             if (previousCaseFellThrough) {
2354                 // Because the previous cases fell through, we can't use a native switch-case here.
2355                 fallthroughCases.push_back(&sc);
2356                 foundFallthroughDefault |= sc.isDefault();
2357 
2358                 this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2359                 fallthroughCases.clear();
2360 
2361                 // Fortunately, we're no longer falling through blocks, so we might be able to use a
2362                 // native switch-case list again.
2363                 previousCaseFellThrough = false;
2364             } else {
2365                 // Emit a native switch-case block with a comma-separated case list.
2366                 nativeCases.push_back(&sc);
2367                 foundNativeDefault |= sc.isDefault();
2368 
2369                 this->writeSwitchCases(nativeCases);
2370                 nativeCases.clear();
2371             }
2372             continue;
2373         }
2374 
2375         // This case falls through, so it will need to be handled via emulation.
2376         // If we have put together a collection of "native" cases (cases that fall through with no
2377         // actual case-body), we will need to slide them over into the fallthrough-case list.
2378         fallthroughCases.push_back_n(nativeCases.size(), nativeCases.data());
2379         nativeCases.clear();
2380 
2381         fallthroughCases.push_back(&sc);
2382         foundFallthroughDefault |= sc.isDefault();
2383         previousCaseFellThrough = true;
2384     }
2385 
2386     // Finish out the remaining switch-cases.
2387     this->writeSwitchCases(nativeCases);
2388     nativeCases.clear();
2389 
2390     this->writeEmulatedSwitchFallthroughCases(fallthroughCases, valueExpr);
2391     fallthroughCases.clear();
2392 
2393     // WGSL requires a default case.
2394     if (!foundNativeDefault && !foundFallthroughDefault) {
2395         this->writeLine("case default {}");
2396     }
2397 
2398     --fIndentation;
2399     this->writeLine("}");
2400 }
2401 
writeVarDeclaration(const VarDeclaration & varDecl)2402 void WGSLCodeGenerator::writeVarDeclaration(const VarDeclaration& varDecl) {
2403     std::string initialValue =
2404             varDecl.value() ? this->assembleExpression(*varDecl.value(), Precedence::kAssignment)
2405                             : std::string();
2406 
2407     if (varDecl.var()->modifierFlags().isConst() ||
2408         (fProgram.fUsage->get(*varDecl.var()).fWrite == 1 && varDecl.value())) {
2409         // Use `const` at global scope, or if the value is a compile-time constant.
2410         SkASSERTF(varDecl.value(), "an immutable variable must specify a value");
2411         const bool useConst =
2412                 !fAtFunctionScope || Analysis::IsCompileTimeConstant(*varDecl.value());
2413         this->write(useConst ? "const " : "let ");
2414     } else {
2415         this->write("var ");
2416     }
2417     this->write(this->assembleName(varDecl.var()->mangledName()));
2418     this->write(": ");
2419     this->write(to_wgsl_type(fContext, varDecl.var()->type(), &varDecl.var()->layout()));
2420 
2421     if (varDecl.value()) {
2422         this->write(" = ");
2423         this->write(initialValue);
2424     }
2425 
2426     this->write(";");
2427 }
2428 
makeLValue(const Expression & e)2429 std::unique_ptr<WGSLCodeGenerator::LValue> WGSLCodeGenerator::makeLValue(const Expression& e) {
2430     if (e.is<VariableReference>()) {
2431         return std::make_unique<PointerLValue>(
2432                 this->variableReferenceNameForLValue(e.as<VariableReference>()));
2433     }
2434     if (e.is<FieldAccess>()) {
2435         return std::make_unique<PointerLValue>(this->assembleFieldAccess(e.as<FieldAccess>()));
2436     }
2437     if (e.is<IndexExpression>()) {
2438         const IndexExpression& idx = e.as<IndexExpression>();
2439         if (idx.base()->type().isVector()) {
2440             // Rewrite indexed-swizzle accesses like `myVec.zyx[i]` into an index onto `myVec`.
2441             if (std::unique_ptr<Expression> rewrite =
2442                         Transform::RewriteIndexedSwizzle(fContext, idx)) {
2443                 return std::make_unique<VectorComponentLValue>(
2444                         this->assembleExpression(*rewrite, Precedence::kAssignment));
2445             } else {
2446                 return std::make_unique<VectorComponentLValue>(this->assembleIndexExpression(idx));
2447             }
2448         } else {
2449             return std::make_unique<PointerLValue>(this->assembleIndexExpression(idx));
2450         }
2451     }
2452     if (e.is<Swizzle>()) {
2453         const Swizzle& swizzle = e.as<Swizzle>();
2454         if (swizzle.components().size() == 1) {
2455             return std::make_unique<VectorComponentLValue>(this->assembleSwizzle(swizzle));
2456         } else {
2457             return std::make_unique<SwizzleLValue>(
2458                     fContext,
2459                     this->assembleExpression(*swizzle.base(), Precedence::kAssignment),
2460                     swizzle.base()->type(),
2461                     swizzle.components());
2462         }
2463     }
2464 
2465     fContext.fErrors->error(e.fPosition, "unsupported lvalue type");
2466     return nullptr;
2467 }
2468 
assembleExpression(const Expression & e,Precedence parentPrecedence)2469 std::string WGSLCodeGenerator::assembleExpression(const Expression& e,
2470                                                   Precedence parentPrecedence) {
2471     switch (e.kind()) {
2472         case Expression::Kind::kBinary:
2473             return this->assembleBinaryExpression(e.as<BinaryExpression>(), parentPrecedence);
2474 
2475         case Expression::Kind::kConstructorCompound:
2476             return this->assembleConstructorCompound(e.as<ConstructorCompound>());
2477 
2478         case Expression::Kind::kConstructorArrayCast:
2479             // This is a no-op, since WGSL 1.0 doesn't have any concept of precision qualifiers.
2480             // When we add support for f16, this will need to copy the array contents.
2481             return this->assembleExpression(*e.as<ConstructorArrayCast>().argument(),
2482                                             parentPrecedence);
2483 
2484         case Expression::Kind::kConstructorArray:
2485         case Expression::Kind::kConstructorCompoundCast:
2486         case Expression::Kind::kConstructorScalarCast:
2487         case Expression::Kind::kConstructorSplat:
2488         case Expression::Kind::kConstructorStruct:
2489             return this->assembleAnyConstructor(e.asAnyConstructor());
2490 
2491         case Expression::Kind::kConstructorDiagonalMatrix:
2492             return this->assembleConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
2493 
2494         case Expression::Kind::kConstructorMatrixResize:
2495             return this->assembleConstructorMatrixResize(e.as<ConstructorMatrixResize>());
2496 
2497         case Expression::Kind::kEmpty:
2498             return "false";
2499 
2500         case Expression::Kind::kFieldAccess:
2501             return this->assembleFieldAccess(e.as<FieldAccess>());
2502 
2503         case Expression::Kind::kFunctionCall:
2504             return this->assembleFunctionCall(e.as<FunctionCall>(), parentPrecedence);
2505 
2506         case Expression::Kind::kIndex:
2507             return this->assembleIndexExpression(e.as<IndexExpression>());
2508 
2509         case Expression::Kind::kLiteral:
2510             return this->assembleLiteral(e.as<Literal>());
2511 
2512         case Expression::Kind::kPrefix:
2513             return this->assemblePrefixExpression(e.as<PrefixExpression>(), parentPrecedence);
2514 
2515         case Expression::Kind::kPostfix:
2516             return this->assemblePostfixExpression(e.as<PostfixExpression>(), parentPrecedence);
2517 
2518         case Expression::Kind::kSetting:
2519             return this->assembleExpression(*e.as<Setting>().toLiteral(fCaps), parentPrecedence);
2520 
2521         case Expression::Kind::kSwizzle:
2522             return this->assembleSwizzle(e.as<Swizzle>());
2523 
2524         case Expression::Kind::kTernary:
2525             return this->assembleTernaryExpression(e.as<TernaryExpression>(), parentPrecedence);
2526 
2527         case Expression::Kind::kVariableReference:
2528             return this->assembleVariableReference(e.as<VariableReference>());
2529 
2530         default:
2531             SkDEBUGFAILF("unsupported expression:\n%s", e.description().c_str());
2532             return {};
2533     }
2534 }
2535 
is_nontrivial_expression(const Expression & expr)2536 static bool is_nontrivial_expression(const Expression& expr) {
2537     // We consider a "trivial expression" one which we can repeat multiple times in the output
2538     // without being dangerous or spammy. We avoid emitting temporary variables for very trivial
2539     // expressions: literals, unadorned variable references, or constant vectors.
2540     if (expr.is<VariableReference>() || expr.is<Literal>()) {
2541         // Variables and literals are trivial; adding a let-declaration won't simplify anything.
2542         return false;
2543     }
2544     if (expr.type().isVector() && Analysis::IsConstantExpression(expr)) {
2545         // Compile-time constant vectors are also considered trivial; they're short and sweet.
2546         return false;
2547     }
2548     return true;
2549 }
2550 
binary_op_is_ambiguous_in_wgsl(Operator op)2551 static bool binary_op_is_ambiguous_in_wgsl(Operator op) {
2552     // WGSL always requires parentheses for some operators which are deemed to be ambiguous.
2553     // (8.19. Operator Precedence and Associativity)
2554     switch (op.kind()) {
2555         case OperatorKind::LOGICALOR:
2556         case OperatorKind::LOGICALAND:
2557         case OperatorKind::BITWISEOR:
2558         case OperatorKind::BITWISEAND:
2559         case OperatorKind::BITWISEXOR:
2560         case OperatorKind::SHL:
2561         case OperatorKind::SHR:
2562         case OperatorKind::LT:
2563         case OperatorKind::GT:
2564         case OperatorKind::LTEQ:
2565         case OperatorKind::GTEQ:
2566             return true;
2567 
2568         default:
2569             return false;
2570     }
2571 }
2572 
binaryOpNeedsComponentwiseMatrixPolyfill(const Type & left,const Type & right,Operator op)2573 bool WGSLCodeGenerator::binaryOpNeedsComponentwiseMatrixPolyfill(const Type& left,
2574                                                                  const Type& right,
2575                                                                  Operator op) {
2576     switch (op.kind()) {
2577         case OperatorKind::SLASH:
2578             // WGSL does not natively support componentwise matrix-op-matrix for division.
2579             if (left.isMatrix() && right.isMatrix()) {
2580                 return true;
2581             }
2582             [[fallthrough]];
2583 
2584         case OperatorKind::PLUS:
2585         case OperatorKind::MINUS:
2586             // WGSL does not natively support componentwise matrix-op-scalar or scalar-op-matrix for
2587             // addition, subtraction or division.
2588             return (left.isMatrix() && right.isScalar()) ||
2589                    (left.isScalar() && right.isMatrix());
2590 
2591         default:
2592             return false;
2593     }
2594 }
2595 
assembleBinaryExpression(const BinaryExpression & b,Precedence parentPrecedence)2596 std::string WGSLCodeGenerator::assembleBinaryExpression(const BinaryExpression& b,
2597                                                         Precedence parentPrecedence) {
2598     return this->assembleBinaryExpression(*b.left(), b.getOperator(), *b.right(), b.type(),
2599                                           parentPrecedence);
2600 }
2601 
assembleBinaryExpression(const Expression & left,Operator op,const Expression & right,const Type & resultType,Precedence parentPrecedence)2602 std::string WGSLCodeGenerator::assembleBinaryExpression(const Expression& left,
2603                                                         Operator op,
2604                                                         const Expression& right,
2605                                                         const Type& resultType,
2606                                                         Precedence parentPrecedence) {
2607     // If the operator is && or ||, we need to handle short-circuiting properly. Specifically, we
2608     // sometimes need to emit extra statements to paper over functionality that WGSL lacks, like
2609     // assignment in the middle of an expression. We need to guard those extra statements, to ensure
2610     // that they don't occur if the expression evaluation is short-circuited. Converting the
2611     // expression into an if-else block keeps the short-circuit property intact even when extra
2612     // statements are involved.
2613     // If the RHS doesn't have any side effects, then it's safe to just leave the expression as-is,
2614     // since we know that any possible extra statements are non-side-effecting.
2615     std::string expr;
2616     if (op.kind() == OperatorKind::LOGICALAND && Analysis::HasSideEffects(right)) {
2617         // Converts `left_expression && right_expression` into the following block:
2618 
2619         // var _skTemp1: bool;
2620         // [[ prepare left_expression ]]
2621         // if left_expression {
2622         //     [[ prepare right_expression ]]
2623         //     _skTemp1 = right_expression;
2624         // } else {
2625         //     _skTemp1 = false;
2626         // }
2627 
2628         expr = this->writeScratchVar(resultType);
2629 
2630         std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2631         this->write("if ");
2632         this->write(leftExpr);
2633         this->writeLine(" {");
2634 
2635         ++fIndentation;
2636         std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2637         this->write(expr);
2638         this->write(" = ");
2639         this->write(rightExpr);
2640         this->writeLine(";");
2641         --fIndentation;
2642 
2643         this->writeLine("} else {");
2644 
2645         ++fIndentation;
2646         this->write(expr);
2647         this->writeLine(" = false;");
2648         --fIndentation;
2649 
2650         this->writeLine("}");
2651         return expr;
2652     }
2653 
2654     if (op.kind() == OperatorKind::LOGICALOR && Analysis::HasSideEffects(right)) {
2655         // Converts `left_expression || right_expression` into the following block:
2656 
2657         // var _skTemp1: bool;
2658         // [[ prepare left_expression ]]
2659         // if left_expression {
2660         //     _skTemp1 = true;
2661         // } else {
2662         //     [[ prepare right_expression ]]
2663         //     _skTemp1 = right_expression;
2664         // }
2665 
2666         expr = this->writeScratchVar(resultType);
2667 
2668         std::string leftExpr = this->assembleExpression(left, Precedence::kExpression);
2669         this->write("if ");
2670         this->write(leftExpr);
2671         this->writeLine(" {");
2672 
2673         ++fIndentation;
2674         this->write(expr);
2675         this->writeLine(" = true;");
2676         --fIndentation;
2677 
2678         this->writeLine("} else {");
2679 
2680         ++fIndentation;
2681         std::string rightExpr = this->assembleExpression(right, Precedence::kAssignment);
2682         this->write(expr);
2683         this->write(" = ");
2684         this->write(rightExpr);
2685         this->writeLine(";");
2686         --fIndentation;
2687 
2688         this->writeLine("}");
2689         return expr;
2690     }
2691 
2692     // Handle comma-expressions.
2693     if (op.kind() == OperatorKind::COMMA) {
2694         // The result from the left-expression is ignored, but its side effects must occur.
2695         this->assembleExpression(left, Precedence::kStatement);
2696 
2697         // Evaluate the right side normally.
2698         return this->assembleExpression(right, parentPrecedence);
2699     }
2700 
2701     // Handle assignment-expressions.
2702     if (op.isAssignment()) {
2703         std::unique_ptr<LValue> lvalue = this->makeLValue(left);
2704         if (!lvalue) {
2705             return "";
2706         }
2707 
2708         if (op.kind() == OperatorKind::EQ) {
2709             // Evaluate the right-hand side of simple assignment (`a = b` --> `b`).
2710             expr = this->assembleExpression(right, Precedence::kAssignment);
2711         } else {
2712             // Evaluate the right-hand side of compound-assignment (`a += b` --> `a + b`).
2713             op = op.removeAssignment();
2714 
2715             std::string lhs = lvalue->load();
2716             std::string rhs = this->assembleExpression(right, op.getBinaryPrecedence());
2717 
2718             if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2719                 if (is_nontrivial_expression(right)) {
2720                     rhs = this->writeScratchLet(rhs);
2721                 }
2722 
2723                 expr = this->assembleComponentwiseMatrixBinary(left.type(), right.type(),
2724                                                                lhs, rhs, op);
2725             } else {
2726                 expr = lhs + operator_name(op) + rhs;
2727             }
2728         }
2729 
2730         // Emit the assignment statement (`a = a + b`).
2731         this->writeLine(lvalue->store(expr));
2732 
2733         // Return the lvalue (`a`) as the result, since the value might be used by the caller.
2734         return lvalue->load();
2735     }
2736 
2737     if (op.isEquality()) {
2738         return this->assembleEqualityExpression(left, right, op, parentPrecedence);
2739     }
2740 
2741     Precedence precedence = op.getBinaryPrecedence();
2742     bool needParens = precedence >= parentPrecedence;
2743     if (binary_op_is_ambiguous_in_wgsl(op)) {
2744         precedence = Precedence::kParentheses;
2745     }
2746     if (needParens) {
2747         expr = "(";
2748     }
2749 
2750     // If we are emitting `constant + constant`, this generally indicates that the values could not
2751     // be constant-folded. This happens when the values overflow or become nan. WGSL will refuse to
2752     // compile such expressions, as WGSL 1.0 has no infinity/nan support. However, the WGSL
2753     // compile-time check can be dodged by putting one side into a let-variable. This technically
2754     // gives us an indeterminate result, but the vast majority of backends will just calculate an
2755     // infinity or nan here, as we would expect. (skia:14385)
2756     bool bothSidesConstant = ConstantFolder::GetConstantValueOrNull(left) &&
2757                              ConstantFolder::GetConstantValueOrNull(right);
2758 
2759     std::string lhs = this->assembleExpression(left, precedence);
2760     std::string rhs = this->assembleExpression(right, precedence);
2761 
2762     if (this->binaryOpNeedsComponentwiseMatrixPolyfill(left.type(), right.type(), op)) {
2763         if (bothSidesConstant || is_nontrivial_expression(left)) {
2764             lhs = this->writeScratchLet(lhs);
2765         }
2766         if (is_nontrivial_expression(right)) {
2767             rhs = this->writeScratchLet(rhs);
2768         }
2769 
2770         expr += this->assembleComponentwiseMatrixBinary(left.type(), right.type(), lhs, rhs, op);
2771     } else {
2772         if (bothSidesConstant) {
2773             lhs = this->writeScratchLet(lhs);
2774         }
2775 
2776         expr += lhs + operator_name(op) + rhs;
2777     }
2778 
2779     if (needParens) {
2780         expr += ')';
2781     }
2782 
2783     return expr;
2784 }
2785 
assembleFieldAccess(const FieldAccess & f)2786 std::string WGSLCodeGenerator::assembleFieldAccess(const FieldAccess& f) {
2787     const Field* field = &f.base()->type().fields()[f.fieldIndex()];
2788     std::string expr;
2789 
2790     if (FieldPolyfillInfo* polyfillInfo = fFieldPolyfillMap.find(field)) {
2791         // We found a matrix uniform. We are required to pass some matrix uniforms as array vectors,
2792         // since the std140 layout for a matrix assumes 4-column vectors for each row, and WGSL
2793         // tightly packs 2-column matrices. When emitting code, we replace the field-access
2794         // expression with a global variable which holds an unpacked version of the uniform.
2795         polyfillInfo->fWasAccessed = true;
2796 
2797         // The polyfill can either be based directly onto a uniform in an interface block, or it
2798         // might be based on an index-expression onto a uniform if the interface block is arrayed.
2799         const Expression* base = f.base().get();
2800         const IndexExpression* indexExpr = nullptr;
2801         if (base->is<IndexExpression>()) {
2802             indexExpr = &base->as<IndexExpression>();
2803             base = indexExpr->base().get();
2804         }
2805 
2806         SkASSERT(base->is<VariableReference>());
2807         expr = polyfillInfo->fReplacementName;
2808 
2809         // If we had an index expression, we must append the index.
2810         if (indexExpr) {
2811             expr += '[';
2812             expr += this->assembleExpression(*indexExpr->index(), Precedence::kSequence);
2813             expr += ']';
2814         }
2815         return expr;
2816     }
2817 
2818     switch (f.ownerKind()) {
2819         case FieldAccess::OwnerKind::kDefault:
2820             expr = this->assembleExpression(*f.base(), Precedence::kPostfix) + '.';
2821             break;
2822 
2823         case FieldAccess::OwnerKind::kAnonymousInterfaceBlock:
2824             if (f.base()->is<VariableReference>() &&
2825                 field->fLayout.fBuiltin != SK_POINTSIZE_BUILTIN) {
2826                 expr = this->variablePrefix(*f.base()->as<VariableReference>().variable());
2827             }
2828             break;
2829     }
2830 
2831     expr += this->assembleName(field->fName);
2832     return expr;
2833 }
2834 
all_arguments_constant(const ExpressionArray & arguments)2835 static bool all_arguments_constant(const ExpressionArray& arguments) {
2836     // Returns true if all arguments in the ExpressionArray are compile-time constants. If we are
2837     // calling an intrinsic and all of its inputs are constant, but we didn't constant-fold it, this
2838     // generally indicates that constant-folding resulted in an infinity or nan. The WGSL compiler
2839     // will reject such an expression with a compile-time error. We can dodge the error, taking on
2840     // the risk of indeterminate behavior instead, by replacing one of the constant values with a
2841     // scratch let-variable. (skia:14385)
2842     for (const std::unique_ptr<Expression>& arg : arguments) {
2843         if (!ConstantFolder::GetConstantValueOrNull(*arg)) {
2844             return false;
2845         }
2846     }
2847     return true;
2848 }
2849 
assembleSimpleIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2850 std::string WGSLCodeGenerator::assembleSimpleIntrinsic(std::string_view intrinsicName,
2851                                                        const FunctionCall& call) {
2852     // Invoke the function, passing each function argument.
2853     std::string expr = std::string(intrinsicName);
2854     expr.push_back('(');
2855     const ExpressionArray& args = call.arguments();
2856     auto separator = SkSL::String::Separator();
2857     bool allConstant = all_arguments_constant(call.arguments());
2858     for (int index = 0; index < args.size(); ++index) {
2859         expr += separator();
2860 
2861         std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2862         if (args[index]->type().isAtomic()) {
2863             // WGSL passes atomic values to intrinsics as pointers.
2864             expr += '&';
2865             expr += argument;
2866         } else if (allConstant && index == 0) {
2867             // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2868             expr += this->writeScratchLet(argument);
2869         } else {
2870             expr += argument;
2871         }
2872     }
2873     expr.push_back(')');
2874 
2875     if (call.type().isVoid()) {
2876         this->write(expr);
2877         this->writeLine(";");
2878         return std::string();
2879     } else {
2880         return this->writeScratchLet(expr);
2881     }
2882 }
2883 
assembleVectorizedIntrinsic(std::string_view intrinsicName,const FunctionCall & call)2884 std::string WGSLCodeGenerator::assembleVectorizedIntrinsic(std::string_view intrinsicName,
2885                                                            const FunctionCall& call) {
2886     SkASSERT(!call.type().isVoid());
2887 
2888     // Invoke the function, passing each function argument.
2889     std::string expr = std::string(intrinsicName);
2890     expr.push_back('(');
2891 
2892     auto separator = SkSL::String::Separator();
2893     const ExpressionArray& args = call.arguments();
2894     bool returnsVector = call.type().isVector();
2895     bool allConstant = all_arguments_constant(call.arguments());
2896     for (int index = 0; index < args.size(); ++index) {
2897         expr += separator();
2898 
2899         bool vectorize = returnsVector && args[index]->type().isScalar();
2900         if (vectorize) {
2901             expr += to_wgsl_type(fContext, call.type());
2902             expr.push_back('(');
2903         }
2904 
2905         // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2906         std::string argument = this->assembleExpression(*args[index], Precedence::kSequence);
2907         expr += (allConstant && index == 0) ? this->writeScratchLet(argument)
2908                                             : argument;
2909         if (vectorize) {
2910             expr.push_back(')');
2911         }
2912     }
2913     expr.push_back(')');
2914 
2915     return this->writeScratchLet(expr);
2916 }
2917 
assembleUnaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2918 std::string WGSLCodeGenerator::assembleUnaryOpIntrinsic(Operator op,
2919                                                         const FunctionCall& call,
2920                                                         Precedence parentPrecedence) {
2921     SkASSERT(!call.type().isVoid());
2922 
2923     bool needParens = Precedence::kPrefix >= parentPrecedence;
2924 
2925     std::string expr;
2926     if (needParens) {
2927         expr.push_back('(');
2928     }
2929 
2930     expr += operator_name(op);
2931     expr += this->assembleExpression(*call.arguments()[0], Precedence::kPrefix);
2932 
2933     if (needParens) {
2934         expr.push_back(')');
2935     }
2936 
2937     return expr;
2938 }
2939 
assembleBinaryOpIntrinsic(Operator op,const FunctionCall & call,Precedence parentPrecedence)2940 std::string WGSLCodeGenerator::assembleBinaryOpIntrinsic(Operator op,
2941                                                          const FunctionCall& call,
2942                                                          Precedence parentPrecedence) {
2943     SkASSERT(!call.type().isVoid());
2944 
2945     Precedence precedence = op.getBinaryPrecedence();
2946     bool needParens = precedence >= parentPrecedence ||
2947                       binary_op_is_ambiguous_in_wgsl(op);
2948     std::string expr;
2949     if (needParens) {
2950         expr.push_back('(');
2951     }
2952 
2953     // We can use a scratch-let for argument 0 to dodge WGSL overflow errors. (skia:14385)
2954     std::string argument = this->assembleExpression(*call.arguments()[0], precedence);
2955     expr += all_arguments_constant(call.arguments()) ? this->writeScratchLet(argument)
2956                                                      : argument;
2957     expr += operator_name(op);
2958     expr += this->assembleExpression(*call.arguments()[1], precedence);
2959 
2960     if (needParens) {
2961         expr.push_back(')');
2962     }
2963 
2964     return expr;
2965 }
2966 
2967 // Rewrite a WGSL intrinsic of the form "intrinsicName(in) -> struct" to the SkSL's
2968 // "intrinsicName(in, outField) -> returnField", where outField and returnField are the names of the
2969 // fields in the struct returned by the WGSL intrinsic.
assembleOutAssignedIntrinsic(std::string_view intrinsicName,std::string_view returnField,std::string_view outField,const FunctionCall & call)2970 std::string WGSLCodeGenerator::assembleOutAssignedIntrinsic(std::string_view intrinsicName,
2971                                                             std::string_view returnField,
2972                                                             std::string_view outField,
2973                                                             const FunctionCall& call) {
2974     SkASSERT(call.type().componentType().isNumber());
2975     SkASSERT(call.arguments().size() == 2);
2976     SkASSERT(call.function().parameters()[1]->modifierFlags() & ModifierFlag::kOut);
2977 
2978     std::string expr = std::string(intrinsicName);
2979     expr += "(";
2980 
2981     // Invoke the intrinsic with the first parameter. Use a scratch-let if argument is a constant
2982     // to dodge WGSL overflow errors. (skia:14385)
2983     std::string argument = this->assembleExpression(*call.arguments()[0], Precedence::kSequence);
2984     expr += ConstantFolder::GetConstantValueOrNull(*call.arguments()[0])
2985             ? this->writeScratchLet(argument) : argument;
2986     expr += ")";
2987     // In WGSL the intrinsic returns a struct; assign it to a local so that its fields can be
2988     // accessed multiple times.
2989     expr = this->writeScratchLet(expr);
2990     expr += ".";
2991 
2992     // Store the outField of `expr` to the intended "out" argument
2993     std::unique_ptr<LValue> lvalue = this->makeLValue(*call.arguments()[1]);
2994     if (!lvalue) {
2995         return "";
2996     }
2997     std::string outValue = expr;
2998     outValue += outField;
2999     this->writeLine(lvalue->store(outValue));
3000 
3001     // And return the expression accessing the returnField.
3002     expr += returnField;
3003     return expr;
3004 }
3005 
assemblePartialSampleCall(std::string_view functionName,const Expression & sampler,const Expression & coords)3006 std::string WGSLCodeGenerator::assemblePartialSampleCall(std::string_view functionName,
3007                                                          const Expression& sampler,
3008                                                          const Expression& coords) {
3009     // This function returns `functionName(inSampler_texture, inSampler_sampler, coords` without a
3010     // terminating comma or close-parenthesis. This allows the caller to add more arguments as
3011     // needed.
3012     SkASSERT(sampler.type().typeKind() == Type::TypeKind::kSampler);
3013     std::string expr = std::string(functionName) + '(';
3014     expr += this->assembleExpression(sampler, Precedence::kSequence);
3015     expr += kTextureSuffix;
3016     expr += ", ";
3017     expr += this->assembleExpression(sampler, Precedence::kSequence);
3018     expr += kSamplerSuffix;
3019     expr += ", ";
3020 
3021     // Compute the sample coordinates, dividing out the Z if a vec3 was provided.
3022     SkASSERT(coords.type().isVector());
3023     if (coords.type().columns() == 3) {
3024         // The coordinates were passed as a vec3, so we need to emit `coords.xy / coords.z`.
3025         std::string vec3Coords = this->writeScratchLet(coords, Precedence::kMultiplicative);
3026         expr += vec3Coords + ".xy / " + vec3Coords + ".z";
3027     } else {
3028         // The coordinates should be a plain vec2; emit the expression as-is.
3029         SkASSERT(coords.type().columns() == 2);
3030         expr += this->assembleExpression(coords, Precedence::kSequence);
3031     }
3032 
3033     return expr;
3034 }
3035 
assembleComponentwiseMatrixBinary(const Type & leftType,const Type & rightType,const std::string & left,const std::string & right,Operator op)3036 std::string WGSLCodeGenerator::assembleComponentwiseMatrixBinary(const Type& leftType,
3037                                                                  const Type& rightType,
3038                                                                  const std::string& left,
3039                                                                  const std::string& right,
3040                                                                  Operator op) {
3041     bool leftIsMatrix = leftType.isMatrix();
3042     bool rightIsMatrix = rightType.isMatrix();
3043     const Type& matrixType = leftIsMatrix ? leftType : rightType;
3044 
3045     std::string expr = to_wgsl_type(fContext, matrixType) + '(';
3046     auto separator = String::Separator();
3047     int columns = matrixType.columns();
3048     for (int c = 0; c < columns; ++c) {
3049         expr += separator();
3050         expr += left;
3051         if (leftIsMatrix) {
3052             expr += '[';
3053             expr += std::to_string(c);
3054             expr += ']';
3055         }
3056         expr += op.operatorName();
3057         expr += right;
3058         if (rightIsMatrix) {
3059             expr += '[';
3060             expr += std::to_string(c);
3061             expr += ']';
3062         }
3063     }
3064     return expr + ')';
3065 }
3066 
assembleIntrinsicCall(const FunctionCall & call,IntrinsicKind kind,Precedence parentPrecedence)3067 std::string WGSLCodeGenerator::assembleIntrinsicCall(const FunctionCall& call,
3068                                                      IntrinsicKind kind,
3069                                                      Precedence parentPrecedence) {
3070     // Be careful: WGSL 1.0 will reject any intrinsic calls which can be constant-evaluated to
3071     // infinity or nan with a compile error. If all arguments to an intrinsic are compile-time
3072     // constants (`all_arguments_constant`), it is safest to copy one argument into a scratch-let so
3073     // that the call will be seen as runtime-evaluated, which defuses the overflow checks.
3074     // Don't worry; a competent driver should still optimize it away.
3075 
3076     const ExpressionArray& arguments = call.arguments();
3077     switch (kind) {
3078         case k_atan_IntrinsicKind: {
3079             const char* name = (arguments.size() == 1) ? "atan" : "atan2";
3080             return this->assembleSimpleIntrinsic(name, call);
3081         }
3082         case k_dFdx_IntrinsicKind:
3083             return this->assembleSimpleIntrinsic("dpdx", call);
3084 
3085         case k_dFdy_IntrinsicKind:
3086             // TODO(b/294274678): apply RTFlip here
3087             return this->assembleSimpleIntrinsic("dpdy", call);
3088 
3089         case k_dot_IntrinsicKind: {
3090             if (arguments[0]->type().isScalar()) {
3091                 return this->assembleBinaryOpIntrinsic(OperatorKind::STAR, call, parentPrecedence);
3092             }
3093             return this->assembleSimpleIntrinsic("dot", call);
3094         }
3095         case k_equal_IntrinsicKind:
3096             return this->assembleBinaryOpIntrinsic(OperatorKind::EQEQ, call, parentPrecedence);
3097 
3098         case k_faceforward_IntrinsicKind: {
3099             if (arguments[0]->type().isScalar()) {
3100                 // select(-N, N, (I * Nref) < 0)
3101                 std::string N = this->writeNontrivialScratchLet(*arguments[0],
3102                                                                 Precedence::kAssignment);
3103                 return this->writeScratchLet(
3104                         "select(-" + N + ", " + N + ", " +
3105                         this->assembleBinaryExpression(*arguments[1],
3106                                                        OperatorKind::STAR,
3107                                                        *arguments[2],
3108                                                        arguments[1]->type(),
3109                                                        Precedence::kRelational) +
3110                         " < 0)");
3111             }
3112             return this->assembleSimpleIntrinsic("faceForward", call);
3113         }
3114         case k_frexp_IntrinsicKind:
3115             // SkSL frexp is "$genType fract = frexp($genType, out $genIType exp)" whereas WGSL
3116             // returns a struct with no out param: "let [fract, exp] = frexp($genType)".
3117             return this->assembleOutAssignedIntrinsic("frexp", "fract", "exp", call);
3118 
3119         case k_greaterThan_IntrinsicKind:
3120             return this->assembleBinaryOpIntrinsic(OperatorKind::GT, call, parentPrecedence);
3121 
3122         case k_greaterThanEqual_IntrinsicKind:
3123             return this->assembleBinaryOpIntrinsic(OperatorKind::GTEQ, call, parentPrecedence);
3124 
3125         case k_inverse_IntrinsicKind:
3126             return this->assembleInversePolyfill(call);
3127 
3128         case k_inversesqrt_IntrinsicKind:
3129             return this->assembleSimpleIntrinsic("inverseSqrt", call);
3130 
3131         case k_lessThan_IntrinsicKind:
3132             return this->assembleBinaryOpIntrinsic(OperatorKind::LT, call, parentPrecedence);
3133 
3134         case k_lessThanEqual_IntrinsicKind:
3135             return this->assembleBinaryOpIntrinsic(OperatorKind::LTEQ, call, parentPrecedence);
3136 
3137         case k_matrixCompMult_IntrinsicKind: {
3138             // We use a scratch-let for arg0 to avoid the potential for WGSL overflow. (skia:14385)
3139             std::string arg0 = all_arguments_constant(arguments)
3140                             ? this->writeScratchLet(*arguments[0], Precedence::kPostfix)
3141                             : this->writeNontrivialScratchLet(*arguments[0], Precedence::kPostfix);
3142             std::string arg1 = this->writeNontrivialScratchLet(*arguments[1], Precedence::kPostfix);
3143             return this->writeScratchLet(
3144                     this->assembleComponentwiseMatrixBinary(arguments[0]->type(),
3145                                                             arguments[1]->type(),
3146                                                             arg0,
3147                                                             arg1,
3148                                                             OperatorKind::STAR));
3149         }
3150         case k_mix_IntrinsicKind: {
3151             const char* name = arguments[2]->type().componentType().isBoolean() ? "select" : "mix";
3152             return this->assembleVectorizedIntrinsic(name, call);
3153         }
3154         case k_mod_IntrinsicKind: {
3155             // WGSL has no intrinsic equivalent to `mod`. Synthesize `x - y * floor(x / y)`.
3156             // We can use a scratch-let on one side to dodge WGSL overflow errors.  In practice, I
3157             // can't find any values of x or y which would overflow, but it can't hurt. (skia:14385)
3158             std::string arg0 = all_arguments_constant(arguments)
3159                             ? this->writeScratchLet(*arguments[0], Precedence::kAdditive)
3160                             : this->writeNontrivialScratchLet(*arguments[0], Precedence::kAdditive);
3161             std::string arg1 = this->writeNontrivialScratchLet(*arguments[1],
3162                                                                Precedence::kAdditive);
3163             return this->writeScratchLet(arg0 + " - " + arg1 + " * floor(" +
3164                                          arg0 + " / " + arg1 + ")");
3165         }
3166 
3167         case k_modf_IntrinsicKind:
3168             // SkSL modf is "$genType fract = modf($genType, out $genType whole)" whereas WGSL
3169             // returns a struct with no out param: "let [fract, whole] = modf($genType)".
3170             return this->assembleOutAssignedIntrinsic("modf", "fract", "whole", call);
3171 
3172         case k_normalize_IntrinsicKind: {
3173             const char* name = arguments[0]->type().isScalar() ? "sign" : "normalize";
3174             return this->assembleSimpleIntrinsic(name, call);
3175         }
3176         case k_not_IntrinsicKind:
3177             return this->assembleUnaryOpIntrinsic(OperatorKind::LOGICALNOT, call, parentPrecedence);
3178 
3179         case k_notEqual_IntrinsicKind:
3180             return this->assembleBinaryOpIntrinsic(OperatorKind::NEQ, call, parentPrecedence);
3181 
3182         case k_packHalf2x16_IntrinsicKind:
3183             return this->assembleSimpleIntrinsic("pack2x16float", call);
3184 
3185         case k_packSnorm2x16_IntrinsicKind:
3186             return this->assembleSimpleIntrinsic("pack2x16snorm", call);
3187 
3188         case k_packSnorm4x8_IntrinsicKind:
3189             return this->assembleSimpleIntrinsic("pack4x8snorm", call);
3190 
3191         case k_packUnorm2x16_IntrinsicKind:
3192             return this->assembleSimpleIntrinsic("pack2x16unorm", call);
3193 
3194         case k_packUnorm4x8_IntrinsicKind:
3195             return this->assembleSimpleIntrinsic("pack4x8unorm", call);
3196 
3197         case k_reflect_IntrinsicKind:
3198             if (arguments[0]->type().isScalar()) {
3199                 // I - 2 * N * I * N
3200                 // We can use a scratch-let for N to dodge WGSL overflow errors. (skia:14385)
3201                 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3202                                                                 Precedence::kAdditive);
3203                 std::string N = all_arguments_constant(arguments)
3204                       ? this->writeScratchLet(*arguments[1], Precedence::kMultiplicative)
3205                       : this->writeNontrivialScratchLet(*arguments[1], Precedence::kMultiplicative);
3206                 return this->writeScratchLet(String::printf("%s - 2 * %s * %s * %s",
3207                                                             I.c_str(), N.c_str(),
3208                                                             I.c_str(), N.c_str()));
3209             }
3210             return this->assembleSimpleIntrinsic("reflect", call);
3211 
3212         case k_refract_IntrinsicKind:
3213             if (arguments[0]->type().isScalar()) {
3214                 // WGSL only implements refract for vectors; rather than reimplementing refract from
3215                 // scratch, we can replace the call with `refract(float2(I,0), float2(N,0), eta).x`.
3216                 std::string I = this->writeNontrivialScratchLet(*arguments[0],
3217                                                                 Precedence::kSequence);
3218                 std::string N = this->writeNontrivialScratchLet(*arguments[1],
3219                                                                 Precedence::kSequence);
3220                 // We can use a scratch-let for Eta to avoid WGSL overflow errors. (skia:14385)
3221                 std::string Eta = all_arguments_constant(arguments)
3222                       ? this->writeScratchLet(*arguments[2], Precedence::kSequence)
3223                       : this->writeNontrivialScratchLet(*arguments[2], Precedence::kSequence);
3224                 return this->writeScratchLet(
3225                         String::printf("refract(vec2<%s>(%s, 0), vec2<%s>(%s, 0), %s).x",
3226                                        to_wgsl_type(fContext, arguments[0]->type()).c_str(),
3227                                        I.c_str(),
3228                                        to_wgsl_type(fContext, arguments[1]->type()).c_str(),
3229                                        N.c_str(),
3230                                        Eta.c_str()));
3231             }
3232             return this->assembleSimpleIntrinsic("refract", call);
3233 
3234         case k_sample_IntrinsicKind: {
3235             // Determine if a bias argument was passed in.
3236             SkASSERT(arguments.size() == 2 || arguments.size() == 3);
3237             bool callIncludesBias = (arguments.size() == 3);
3238 
3239             if (fProgram.fConfig->fSettings.fSharpenTextures || callIncludesBias) {
3240                 // We need to supply a bias argument; this is a separate intrinsic in WGSL.
3241                 std::string expr = this->assemblePartialSampleCall("textureSampleBias",
3242                                                                    *arguments[0],
3243                                                                    *arguments[1]);
3244                 expr += ", ";
3245                 if (callIncludesBias) {
3246                     expr += this->assembleExpression(*arguments[2], Precedence::kAdditive) +
3247                             " + ";
3248                 }
3249                 expr += skstd::to_string(fProgram.fConfig->fSettings.fSharpenTextures
3250                                                  ? kSharpenTexturesBias
3251                                                  : 0.0f);
3252                 return expr + ')';
3253             }
3254 
3255             // No bias is necessary, so we can call `textureSample` directly.
3256             return this->assemblePartialSampleCall("textureSample",
3257                                                    *arguments[0],
3258                                                    *arguments[1]) + ')';
3259         }
3260         case k_sampleLod_IntrinsicKind: {
3261             std::string expr = this->assemblePartialSampleCall("textureSampleLevel",
3262                                                                *arguments[0],
3263                                                                *arguments[1]);
3264             expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3265             return expr + ')';
3266         }
3267         case k_sampleGrad_IntrinsicKind: {
3268             std::string expr = this->assemblePartialSampleCall("textureSampleGrad",
3269                                                                *arguments[0],
3270                                                                *arguments[1]);
3271             expr += ", " + this->assembleExpression(*arguments[2], Precedence::kSequence);
3272             expr += ", " + this->assembleExpression(*arguments[3], Precedence::kSequence);
3273             return expr + ')';
3274         }
3275         case k_textureHeight_IntrinsicKind:
3276             return this->assembleSimpleIntrinsic("textureDimensions", call) + ".y";
3277 
3278         case k_textureRead_IntrinsicKind: {
3279             // We need to inject an extra argument for the mip-level. We don't plan on using mipmaps
3280             // in our storage textures, so we can just pass zero.
3281             std::string tex = this->assembleExpression(*arguments[0], Precedence::kSequence);
3282             std::string pos = this->writeScratchLet(*arguments[1], Precedence::kSequence);
3283             return std::string("textureLoad(") + tex + ", " + pos + ", 0)";
3284         }
3285         case k_textureWidth_IntrinsicKind:
3286             return this->assembleSimpleIntrinsic("textureDimensions", call) + ".x";
3287 
3288         case k_textureWrite_IntrinsicKind:
3289             return this->assembleSimpleIntrinsic("textureStore", call);
3290 
3291         case k_unpackHalf2x16_IntrinsicKind:
3292             return this->assembleSimpleIntrinsic("unpack2x16float", call);
3293 
3294         case k_unpackSnorm2x16_IntrinsicKind:
3295             return this->assembleSimpleIntrinsic("unpack2x16snorm", call);
3296 
3297         case k_unpackSnorm4x8_IntrinsicKind:
3298             return this->assembleSimpleIntrinsic("unpack4x8snorm", call);
3299 
3300         case k_unpackUnorm2x16_IntrinsicKind:
3301             return this->assembleSimpleIntrinsic("unpack2x16unorm", call);
3302 
3303         case k_unpackUnorm4x8_IntrinsicKind:
3304             return this->assembleSimpleIntrinsic("unpack4x8unorm", call);
3305 
3306         case k_clamp_IntrinsicKind:
3307         case k_max_IntrinsicKind:
3308         case k_min_IntrinsicKind:
3309         case k_smoothstep_IntrinsicKind:
3310         case k_step_IntrinsicKind:
3311             return this->assembleVectorizedIntrinsic(call.function().name(), call);
3312 
3313         case k_abs_IntrinsicKind:
3314         case k_acos_IntrinsicKind:
3315         case k_all_IntrinsicKind:
3316         case k_any_IntrinsicKind:
3317         case k_asin_IntrinsicKind:
3318         case k_atomicAdd_IntrinsicKind:
3319         case k_atomicLoad_IntrinsicKind:
3320         case k_atomicStore_IntrinsicKind:
3321         case k_ceil_IntrinsicKind:
3322         case k_cos_IntrinsicKind:
3323         case k_cross_IntrinsicKind:
3324         case k_degrees_IntrinsicKind:
3325         case k_distance_IntrinsicKind:
3326         case k_exp_IntrinsicKind:
3327         case k_exp2_IntrinsicKind:
3328         case k_floor_IntrinsicKind:
3329         case k_fract_IntrinsicKind:
3330         case k_length_IntrinsicKind:
3331         case k_log_IntrinsicKind:
3332         case k_log2_IntrinsicKind:
3333         case k_radians_IntrinsicKind:
3334         case k_pow_IntrinsicKind:
3335         case k_saturate_IntrinsicKind:
3336         case k_sign_IntrinsicKind:
3337         case k_sin_IntrinsicKind:
3338         case k_sqrt_IntrinsicKind:
3339         case k_storageBarrier_IntrinsicKind:
3340         case k_tan_IntrinsicKind:
3341         case k_workgroupBarrier_IntrinsicKind:
3342         default:
3343             return this->assembleSimpleIntrinsic(call.function().name(), call);
3344     }
3345 }
3346 
// WGSL source for `mat2_inverse`. assembleInversePolyfill writes this into the header (once,
// guarded by fWrittenInverse2) to implement SkSL `inverse` on 2x2 matrices. It returns the
// adjugate scaled by 1/determinant(m); a singular matrix is not special-cased.
static constexpr char kInverse2x2[] =
    "fn mat2_inverse(m: mat2x2<f32>) -> mat2x2<f32> {\n"
        "return mat2x2<f32>(m[1].y, -m[0].y, -m[1].x, m[0].x) * (1/determinant(m));\n"
    "}\n";
3352 
// WGSL source for `mat3_inverse`. assembleInversePolyfill writes this into the header (once,
// guarded by fWrittenInverse3) to implement SkSL `inverse` on 3x3 matrices. The determinant is
// expanded along the first column via the cofactors b01/b11/b21, and the adjugate matrix is
// scaled by 1/det. A singular matrix is not special-cased.
static constexpr char kInverse3x3[] =
    "fn mat3_inverse(m: mat3x3<f32>) -> mat3x3<f32> {\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z;\n"
        "let b01 =  a22*a11 - a12*a21;\n"
        "let b11 = -a22*a10 + a12*a20;\n"
        "let b21 =  a21*a10 - a11*a20;\n"
        "let det = a00*b01 + a01*b11 + a02*b21;\n"
        "return mat3x3<f32>(b01, (-a22*a01 + a02*a21), ( a12*a01 - a02*a11),\n"
                           "b11, ( a22*a00 - a02*a20), (-a12*a00 + a02*a10),\n"
                           "b21, (-a21*a00 + a01*a20), ( a11*a00 - a01*a10)) * (1/det);\n"
    "}\n";
3367 
// WGSL source for `mat4_inverse`. assembleInversePolyfill writes this into the header (once,
// guarded by fWrittenInverse4) to implement SkSL `inverse` on 4x4 matrices. The twelve 2x2
// sub-determinants b00..b11 are reused both for the determinant and for every adjugate entry,
// and the adjugate is scaled by 1/det. A singular matrix is not special-cased.
static constexpr char kInverse4x4[] =
    "fn mat4_inverse(m: mat4x4<f32>) -> mat4x4<f32>{\n"
        "let a00 = m[0].x; let a01 = m[0].y; let a02 = m[0].z; let a03 = m[0].w;\n"
        "let a10 = m[1].x; let a11 = m[1].y; let a12 = m[1].z; let a13 = m[1].w;\n"
        "let a20 = m[2].x; let a21 = m[2].y; let a22 = m[2].z; let a23 = m[2].w;\n"
        "let a30 = m[3].x; let a31 = m[3].y; let a32 = m[3].z; let a33 = m[3].w;\n"
        "let b00 = a00*a11 - a01*a10;\n"
        "let b01 = a00*a12 - a02*a10;\n"
        "let b02 = a00*a13 - a03*a10;\n"
        "let b03 = a01*a12 - a02*a11;\n"
        "let b04 = a01*a13 - a03*a11;\n"
        "let b05 = a02*a13 - a03*a12;\n"
        "let b06 = a20*a31 - a21*a30;\n"
        "let b07 = a20*a32 - a22*a30;\n"
        "let b08 = a20*a33 - a23*a30;\n"
        "let b09 = a21*a32 - a22*a31;\n"
        "let b10 = a21*a33 - a23*a31;\n"
        "let b11 = a22*a33 - a23*a32;\n"
        "let det = b00*b11 - b01*b10 + b02*b09 + b03*b08 - b04*b07 + b05*b06;\n"
        "return mat4x4<f32>(a11*b11 - a12*b10 + a13*b09,\n"
                           "a02*b10 - a01*b11 - a03*b09,\n"
                           "a31*b05 - a32*b04 + a33*b03,\n"
                           "a22*b04 - a21*b05 - a23*b03,\n"
                           "a12*b08 - a10*b11 - a13*b07,\n"
                           "a00*b11 - a02*b08 + a03*b07,\n"
                           "a32*b02 - a30*b05 - a33*b01,\n"
                           "a20*b05 - a22*b02 + a23*b01,\n"
                           "a10*b10 - a11*b08 + a13*b06,\n"
                           "a01*b08 - a00*b10 - a03*b06,\n"
                           "a30*b04 - a31*b02 + a33*b00,\n"
                           "a21*b02 - a20*b04 - a23*b00,\n"
                           "a11*b07 - a10*b09 - a12*b06,\n"
                           "a00*b09 - a01*b07 + a02*b06,\n"
                           "a31*b01 - a30*b03 - a32*b00,\n"
                           "a20*b03 - a21*b01 + a22*b00) * (1/det);\n"
    "}\n";
3405 
assembleInversePolyfill(const FunctionCall & call)3406 std::string WGSLCodeGenerator::assembleInversePolyfill(const FunctionCall& call) {
3407     const ExpressionArray& arguments = call.arguments();
3408     const Type& type = arguments.front()->type();
3409 
3410     // The `inverse` intrinsic should only accept a single-argument square matrix.
3411     // Once we implement f16 support, these polyfills will need to be updated to support `hmat`;
3412     // for the time being, all floats in WGSL are f32, so we don't need to worry about precision.
3413     SkASSERT(arguments.size() == 1);
3414     SkASSERT(type.isMatrix());
3415     SkASSERT(type.rows() == type.columns());
3416 
3417     switch (type.slotCount()) {
3418         case 4:
3419             if (!fWrittenInverse2) {
3420                 fWrittenInverse2 = true;
3421                 fHeader.writeText(kInverse2x2);
3422             }
3423             return this->assembleSimpleIntrinsic("mat2_inverse", call);
3424 
3425         case 9:
3426             if (!fWrittenInverse3) {
3427                 fWrittenInverse3 = true;
3428                 fHeader.writeText(kInverse3x3);
3429             }
3430             return this->assembleSimpleIntrinsic("mat3_inverse", call);
3431 
3432         case 16:
3433             if (!fWrittenInverse4) {
3434                 fWrittenInverse4 = true;
3435                 fHeader.writeText(kInverse4x4);
3436             }
3437             return this->assembleSimpleIntrinsic("mat4_inverse", call);
3438 
3439         default:
3440             // We only support square matrices.
3441             SkUNREACHABLE;
3442     }
3443 }
3444 
assembleFunctionCall(const FunctionCall & call,Precedence parentPrecedence)3445 std::string WGSLCodeGenerator::assembleFunctionCall(const FunctionCall& call,
3446                                                     Precedence parentPrecedence) {
3447     const FunctionDeclaration& func = call.function();
3448     std::string result;
3449 
3450     // Many intrinsics need to be rewritten in WGSL.
3451     if (func.isIntrinsic()) {
3452         return this->assembleIntrinsicCall(call, func.intrinsicKind(), parentPrecedence);
3453     }
3454 
3455     // We implement function out-parameters by declaring them as pointers. SkSL follows GLSL's
3456     // out-parameter semantics, in which out-parameters are only written back to the original
3457     // variable after the function's execution is complete (see
3458     // https://www.khronos.org/opengl/wiki/Core_Language_(GLSL)#Parameters).
3459     //
3460     // In addition, SkSL supports swizzles and array index expressions to be passed into
3461     // out-parameters; however, WGSL does not allow taking their address into a pointer.
3462     //
3463     // We support these by using LValues to create temporary copies and then pass pointers to the
3464     // copies. Once the function returns, we copy the values back to the LValue.
3465 
3466     // First detect which arguments are passed to out-parameters.
3467     // TODO: rewrite this method in terms of LValues.
3468     const ExpressionArray& args = call.arguments();
3469     SkSpan<Variable* const> params = func.parameters();
3470     SkASSERT(SkToSizeT(args.size()) == params.size());
3471 
3472     STArray<16, std::unique_ptr<LValue>> writeback;
3473     STArray<16, std::string> substituteArgument;
3474     writeback.reserve_exact(args.size());
3475     substituteArgument.reserve_exact(args.size());
3476 
3477     for (int index = 0; index < args.size(); ++index) {
3478         if (params[index]->modifierFlags() & ModifierFlag::kOut) {
3479             std::unique_ptr<LValue> lvalue = this->makeLValue(*args[index]);
3480             if (params[index]->modifierFlags() & ModifierFlag::kIn) {
3481                 // Load the lvalue's contents into the substitute argument.
3482                 substituteArgument.push_back(this->writeScratchVar(args[index]->type(),
3483                                                                    lvalue->load()));
3484             } else {
3485                 // Create a substitute argument, but leave it uninitialized.
3486                 substituteArgument.push_back(this->writeScratchVar(args[index]->type()));
3487             }
3488             writeback.push_back(std::move(lvalue));
3489         } else {
3490             substituteArgument.push_back(std::string());
3491             writeback.push_back(nullptr);
3492         }
3493     }
3494 
3495     std::string expr = this->assembleName(func.mangledName());
3496     expr.push_back('(');
3497     auto separator = SkSL::String::Separator();
3498 
3499     if (std::string funcDepArgs = this->functionDependencyArgs(func); !funcDepArgs.empty()) {
3500         expr += funcDepArgs;
3501         separator();
3502     }
3503 
3504     // Pass the function arguments, or any substitutes as needed.
3505     for (int index = 0; index < args.size(); ++index) {
3506         expr += separator();
3507         if (!substituteArgument[index].empty()) {
3508             // We need to take the address of the variable and pass it down as a pointer.
3509             expr += '&' + substituteArgument[index];
3510         } else if (args[index]->type().isSampler()) {
3511             // If the argument is a sampler, we need to pass the texture _and_ its associated
3512             // sampler. (Function parameter lists also convert sampler parameters into a matching
3513             // texture/sampler parameter pair.)
3514             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3515             expr += kTextureSuffix;
3516             expr += ", ";
3517             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3518             expr += kSamplerSuffix;
3519         } else if (args[index]->type().isUnsizedArray()) {
3520             // If the array is in the parameter storage space then manually just pass it through
3521             // since it is already a pointer.
3522             if (args[index]->is<VariableReference>()) {
3523                 const Variable* v = args[index]->as<VariableReference>().variable();
3524                 // A variable reference to an unsized array should always be a parameter,
3525                 // because unsized arrays coming from uniforms will have the `FieldAccess`
3526                 // expression type.
3527                 SkASSERT(v->storage() == Variable::Storage::kParameter);
3528                 expr += this->assembleName(v->mangledName());
3529             } else {
3530                 expr += "&(" + this->assembleExpression(*args[index], Precedence::kSequence) + ")";
3531             }
3532         } else {
3533             expr += this->assembleExpression(*args[index], Precedence::kSequence);
3534         }
3535     }
3536     expr += ')';
3537 
3538     if (call.type().isVoid()) {
3539         // Making function calls that result in `void` is only valid in on the left side of a
3540         // comma-sequence, or in a top-level statement. Emit the function call as a top-level
3541         // statement and return an empty string, as the result will not be used.
3542         SkASSERT(parentPrecedence >= Precedence::kSequence);
3543         this->write(expr);
3544         this->writeLine(";");
3545     } else {
3546         result = this->writeScratchLet(expr);
3547     }
3548 
3549     // Write the substitute arguments back into their lvalues.
3550     for (int index = 0; index < args.size(); ++index) {
3551         if (!substituteArgument[index].empty()) {
3552             this->writeLine(writeback[index]->store(substituteArgument[index]));
3553         }
3554     }
3555 
3556     // Return the result of invoking the function.
3557     return result;
3558 }
3559 
assembleIndexExpression(const IndexExpression & i)3560 std::string WGSLCodeGenerator::assembleIndexExpression(const IndexExpression& i) {
3561     // Put the index value into a let-expression.
3562     std::string idx = this->writeNontrivialScratchLet(*i.index(), Precedence::kExpression);
3563     return this->assembleExpression(*i.base(), Precedence::kPostfix) + "[" + idx + "]";
3564 }
3565 
assembleLiteral(const Literal & l)3566 std::string WGSLCodeGenerator::assembleLiteral(const Literal& l) {
3567     const Type& type = l.type();
3568     if (type.isFloat() || type.isBoolean()) {
3569         return l.description(OperatorPrecedence::kExpression);
3570     }
3571     SkASSERT(type.isInteger());
3572     if (type.matches(*fContext.fTypes.fUInt)) {
3573         return std::to_string(l.intValue() & 0xffffffff) + "u";
3574     } else if (type.matches(*fContext.fTypes.fUShort)) {
3575         return std::to_string(l.intValue() & 0xffff) + "u";
3576     } else {
3577         return std::to_string(l.intValue());
3578     }
3579 }
3580 
assembleIncrementExpr(const Type & type)3581 std::string WGSLCodeGenerator::assembleIncrementExpr(const Type& type) {
3582     // `type(`
3583     std::string expr = to_wgsl_type(fContext, type);
3584     expr.push_back('(');
3585 
3586     // `1, 1, 1...)`
3587     auto separator = SkSL::String::Separator();
3588     for (int slots = type.slotCount(); slots > 0; --slots) {
3589         expr += separator();
3590         expr += "1";
3591     }
3592     expr.push_back(')');
3593     return expr;
3594 }
3595 
assemblePrefixExpression(const PrefixExpression & p,Precedence parentPrecedence)3596 std::string WGSLCodeGenerator::assemblePrefixExpression(const PrefixExpression& p,
3597                                                         Precedence parentPrecedence) {
3598     // Unary + does nothing, so we omit it from the output.
3599     Operator op = p.getOperator();
3600     if (op.kind() == Operator::Kind::PLUS) {
3601         return this->assembleExpression(*p.operand(), Precedence::kPrefix);
3602     }
3603 
3604     // Pre-increment/decrement expressions have no direct equivalent in WGSL.
3605     if (op.kind() == Operator::Kind::PLUSPLUS || op.kind() == Operator::Kind::MINUSMINUS) {
3606         std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3607         if (!lvalue) {
3608             return "";
3609         }
3610 
3611         // Generate the new value: `lvalue + type(1, 1, 1...)`.
3612         std::string newValue =
3613                 lvalue->load() +
3614                 (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3615                 this->assembleIncrementExpr(p.operand()->type());
3616         this->writeLine(lvalue->store(newValue));
3617         return lvalue->load();
3618     }
3619 
3620     // WGSL natively supports unary negation/not expressions (!,~,-).
3621     SkASSERT(op.kind() == OperatorKind::LOGICALNOT ||
3622              op.kind() == OperatorKind::BITWISENOT ||
3623              op.kind() == OperatorKind::MINUS);
3624 
3625     // The unary negation operator only applies to scalars and vectors. For other mathematical
3626     // objects (such as matrices) we can express it as a multiplication by -1.
3627     std::string expr;
3628     const bool needsNegation = op.kind() == Operator::Kind::MINUS &&
3629                                !p.operand()->type().isScalar() && !p.operand()->type().isVector();
3630     const bool needParens = needsNegation || Precedence::kPrefix >= parentPrecedence;
3631 
3632     if (needParens) {
3633         expr.push_back('(');
3634     }
3635 
3636     if (needsNegation) {
3637         expr += "-1.0 * ";
3638         expr += this->assembleExpression(*p.operand(), Precedence::kMultiplicative);
3639     } else {
3640         expr += p.getOperator().tightOperatorName();
3641         expr += this->assembleExpression(*p.operand(), Precedence::kPrefix);
3642     }
3643 
3644     if (needParens) {
3645         expr.push_back(')');
3646     }
3647 
3648     return expr;
3649 }
3650 
assemblePostfixExpression(const PostfixExpression & p,Precedence parentPrecedence)3651 std::string WGSLCodeGenerator::assemblePostfixExpression(const PostfixExpression& p,
3652                                                          Precedence parentPrecedence) {
3653     SkASSERT(p.getOperator().kind() == Operator::Kind::PLUSPLUS ||
3654              p.getOperator().kind() == Operator::Kind::MINUSMINUS);
3655 
3656     // Post-increment/decrement expressions have no direct equivalent in WGSL; they do exist as a
3657     // standalone statement for convenience, but these aren't the same as SkSL's post-increments.
3658     std::unique_ptr<LValue> lvalue = this->makeLValue(*p.operand());
3659     if (!lvalue) {
3660         return "";
3661     }
3662 
3663     // If the expression is used, create a let-copy of the original value.
3664     // (At statement-level precedence, we know the value is unused and can skip this let-copy.)
3665     std::string originalValue;
3666     if (parentPrecedence != Precedence::kStatement) {
3667         originalValue = this->writeScratchLet(lvalue->load());
3668     }
3669     // Generate the new value: `lvalue + type(1, 1, 1...)`.
3670     std::string newValue = lvalue->load() +
3671                            (p.getOperator().kind() == Operator::Kind::PLUSPLUS ? " + " : " - ") +
3672                            this->assembleIncrementExpr(p.operand()->type());
3673     this->writeLine(lvalue->store(newValue));
3674 
3675     return originalValue;
3676 }
3677 
assembleSwizzle(const Swizzle & swizzle)3678 std::string WGSLCodeGenerator::assembleSwizzle(const Swizzle& swizzle) {
3679     return this->assembleExpression(*swizzle.base(), Precedence::kPostfix) + "." +
3680            Swizzle::MaskString(swizzle.components());
3681 }
3682 
writeScratchVar(const Type & type,const std::string & value)3683 std::string WGSLCodeGenerator::writeScratchVar(const Type& type, const std::string& value) {
3684     std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3685     this->write("var ");
3686     this->write(scratchVarName);
3687     this->write(": ");
3688     this->write(to_wgsl_type(fContext, type));
3689     if (!value.empty()) {
3690         this->write(" = ");
3691         this->write(value);
3692     }
3693     this->writeLine(";");
3694     return scratchVarName;
3695 }
3696 
writeScratchLet(const std::string & expr,bool isCompileTimeConstant)3697 std::string WGSLCodeGenerator::writeScratchLet(const std::string& expr,
3698                                                bool isCompileTimeConstant) {
3699     std::string scratchVarName = "_skTemp" + std::to_string(fScratchCount++);
3700     this->write(fAtFunctionScope && !isCompileTimeConstant ? "let " : "const ");
3701     this->write(scratchVarName);
3702     this->write(" = ");
3703     this->write(expr);
3704     this->writeLine(";");
3705     return scratchVarName;
3706 }
3707 
writeScratchLet(const Expression & expr,Precedence parentPrecedence)3708 std::string WGSLCodeGenerator::writeScratchLet(const Expression& expr,
3709                                                Precedence parentPrecedence) {
3710     return this->writeScratchLet(this->assembleExpression(expr, parentPrecedence));
3711 }
3712 
writeNontrivialScratchLet(const Expression & expr,Precedence parentPrecedence)3713 std::string WGSLCodeGenerator::writeNontrivialScratchLet(const Expression& expr,
3714                                                          Precedence parentPrecedence) {
3715     std::string result = this->assembleExpression(expr, parentPrecedence);
3716     return is_nontrivial_expression(expr)
3717                    ? this->writeScratchLet(result, Analysis::IsCompileTimeConstant(expr))
3718                    : result;
3719 }
3720 
assembleTernaryExpression(const TernaryExpression & t,Precedence parentPrecedence)3721 std::string WGSLCodeGenerator::assembleTernaryExpression(const TernaryExpression& t,
3722                                                          Precedence parentPrecedence) {
3723     std::string expr;
3724 
3725     // The trivial case is when neither branch has side effects and evaluate to a scalar or vector
3726     // type. This can be represented with a call to the WGSL `select` intrinsic. Select doesn't
3727     // support short-circuiting, so we should only use it when both the true- and false-expressions
3728     // are trivial to evaluate.
3729     if ((t.type().isScalar() || t.type().isVector()) &&
3730         !Analysis::HasSideEffects(*t.test()) &&
3731         Analysis::IsTrivialExpression(*t.ifTrue()) &&
3732         Analysis::IsTrivialExpression(*t.ifFalse())) {
3733 
3734         bool needParens = Precedence::kTernary >= parentPrecedence;
3735         if (needParens) {
3736             expr.push_back('(');
3737         }
3738         expr += "select(";
3739         expr += this->assembleExpression(*t.ifFalse(), Precedence::kSequence);
3740         expr += ", ";
3741         expr += this->assembleExpression(*t.ifTrue(), Precedence::kSequence);
3742         expr += ", ";
3743 
3744         bool isVector = t.type().isVector();
3745         if (isVector) {
3746             // Splat the condition expression into a vector.
3747             expr += String::printf("vec%d<bool>(", t.type().columns());
3748         }
3749         expr += this->assembleExpression(*t.test(), Precedence::kSequence);
3750         if (isVector) {
3751             expr.push_back(')');
3752         }
3753         expr.push_back(')');
3754         if (needParens) {
3755             expr.push_back(')');
3756         }
3757     } else {
3758         // WGSL does not support ternary expressions. Instead, we hoist the expression out into the
3759         // surrounding block, convert it into an if statement, and write the result to a synthesized
3760         // variable. Instead of the original expression, we return that variable.
3761         expr = this->writeScratchVar(t.ifTrue()->type());
3762 
3763         std::string testExpr = this->assembleExpression(*t.test(), Precedence::kExpression);
3764         this->write("if ");
3765         this->write(testExpr);
3766         this->writeLine(" {");
3767 
3768         ++fIndentation;
3769         std::string trueExpr = this->assembleExpression(*t.ifTrue(), Precedence::kAssignment);
3770         this->write(expr);
3771         this->write(" = ");
3772         this->write(trueExpr);
3773         this->writeLine(";");
3774         --fIndentation;
3775 
3776         this->writeLine("} else {");
3777 
3778         ++fIndentation;
3779         std::string falseExpr = this->assembleExpression(*t.ifFalse(), Precedence::kAssignment);
3780         this->write(expr);
3781         this->write(" = ");
3782         this->write(falseExpr);
3783         this->writeLine(";");
3784         --fIndentation;
3785 
3786         this->writeLine("}");
3787     }
3788     return expr;
3789 }
3790 
variablePrefix(const Variable & v)3791 std::string WGSLCodeGenerator::variablePrefix(const Variable& v) {
3792     if (v.storage() == Variable::Storage::kGlobal) {
3793         // If the field refers to a pipeline IO parameter, then we access it via the synthesized IO
3794         // structs. We make an explicit exception for `sk_PointSize` which we declare as a
3795         // placeholder variable in global scope as it is not supported by WebGPU as a pipeline IO
3796         // parameter (see comments in `writeStageOutputStruct`).
3797         if (v.modifierFlags() & ModifierFlag::kIn) {
3798             return "_stageIn.";
3799         }
3800         if (v.modifierFlags() & ModifierFlag::kOut) {
3801             return "(*_stageOut).";
3802         }
3803 
3804         // If the field refers to an anonymous-interface-block structure, access it via the
3805         // synthesized `_uniform0` or `_storage1` global.
3806         if (const InterfaceBlock* ib = v.interfaceBlock()) {
3807             const Type& ibType = ib->var()->type().componentType();
3808             if (const std::string* ibName = fInterfaceBlockNameMap.find(&ibType)) {
3809                 return *ibName + '.';
3810             }
3811         }
3812 
3813         // If the field refers to an top-level uniform, access it via the synthesized
3814         // `_globalUniforms` global. (Note that this should only occur in test code; Skia will
3815         // always put uniforms in an interface block.)
3816         if (is_in_global_uniforms(v)) {
3817             return "_globalUniforms.";
3818         }
3819     }
3820 
3821     return "";
3822 }
3823 
variableReferenceNameForLValue(const VariableReference & r)3824 std::string WGSLCodeGenerator::variableReferenceNameForLValue(const VariableReference& r) {
3825     const Variable& v = *r.variable();
3826 
3827     if (v.storage() == Variable::Storage::kParameter &&
3828          (v.modifierFlags() & ModifierFlag::kOut || v.type().isUnsizedArray())) {
3829         // This is an out-parameter or unsized array parameter; it's pointer-typed, so we need to
3830         // dereference it. We wrap the dereference in parentheses, in case the value is used in an
3831         // access expression later.
3832         return "(*" + this->assembleName(v.mangledName()) + ')';
3833     }
3834 
3835     return this->variablePrefix(v) + this->assembleName(v.mangledName());
3836 }
3837 
assembleVariableReference(const VariableReference & r)3838 std::string WGSLCodeGenerator::assembleVariableReference(const VariableReference& r) {
3839     // TODO(b/294274678): Correctly handle RTFlip for built-ins.
3840     const Variable& v = *r.variable();
3841 
3842     // Insert a conversion expression if this is a built-in variable whose type differs from the
3843     // SkSL.
3844     std::string expr;
3845     std::optional<std::string_view> conversion = needs_builtin_type_conversion(v);
3846     if (conversion.has_value()) {
3847         expr += *conversion;
3848         expr.push_back('(');
3849     }
3850 
3851     expr += this->variableReferenceNameForLValue(r);
3852 
3853     if (conversion.has_value()) {
3854         expr.push_back(')');
3855     }
3856 
3857     return expr;
3858 }
3859 
assembleAnyConstructor(const AnyConstructor & c)3860 std::string WGSLCodeGenerator::assembleAnyConstructor(const AnyConstructor& c) {
3861     std::string expr = to_wgsl_type(fContext, c.type());
3862     expr.push_back('(');
3863     auto separator = SkSL::String::Separator();
3864     for (const auto& e : c.argumentSpan()) {
3865         expr += separator();
3866         expr += this->assembleExpression(*e, Precedence::kSequence);
3867     }
3868     expr.push_back(')');
3869     return expr;
3870 }
3871 
assembleConstructorCompound(const ConstructorCompound & c)3872 std::string WGSLCodeGenerator::assembleConstructorCompound(const ConstructorCompound& c) {
3873     if (c.type().isVector()) {
3874         return this->assembleConstructorCompoundVector(c);
3875     } else if (c.type().isMatrix()) {
3876         return this->assembleConstructorCompoundMatrix(c);
3877     } else {
3878         fContext.fErrors->error(c.fPosition, "unsupported compound constructor");
3879         return {};
3880     }
3881 }
3882 
assembleConstructorCompoundVector(const ConstructorCompound & c)3883 std::string WGSLCodeGenerator::assembleConstructorCompoundVector(const ConstructorCompound& c) {
3884     // WGSL supports constructing vectors from a mix of scalars and vectors but
3885     // not matrices (see https://www.w3.org/TR/WGSL/#type-constructor-expr).
3886     //
3887     // SkSL supports vec4(mat2x2) which we handle specially.
3888     if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
3889         const Expression& arg = *c.argumentSpan().front();
3890         if (arg.type().isMatrix()) {
3891             SkASSERT(arg.type().columns() == 2);
3892             SkASSERT(arg.type().rows() == 2);
3893 
3894             std::string matrix = this->writeNontrivialScratchLet(arg, Precedence::kPostfix);
3895             return String::printf("%s(%s[0], %s[1])", to_wgsl_type(fContext, c.type()).c_str(),
3896                                                       matrix.c_str(),
3897                                                       matrix.c_str());
3898         }
3899     }
3900     return this->assembleAnyConstructor(c);
3901 }
3902 
assembleConstructorCompoundMatrix(const ConstructorCompound & ctor)3903 std::string WGSLCodeGenerator::assembleConstructorCompoundMatrix(const ConstructorCompound& ctor) {
3904     SkASSERT(ctor.type().isMatrix());
3905 
3906     std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3907     auto separator = String::Separator();
3908     for (const std::unique_ptr<Expression>& arg : ctor.arguments()) {
3909         SkASSERT(arg->type().isScalar() || arg->type().isVector());
3910 
3911         if (arg->type().isScalar()) {
3912             expr += separator();
3913             expr += this->assembleExpression(*arg, Precedence::kSequence);
3914         } else {
3915             std::string inner = this->writeNontrivialScratchLet(*arg, Precedence::kSequence);
3916             int numSlots = arg->type().slotCount();
3917             for (int slot = 0; slot < numSlots; ++slot) {
3918                 String::appendf(&expr, "%s%s[%d]", separator().c_str(), inner.c_str(), slot);
3919             }
3920         }
3921     }
3922     return expr + ')';
3923 }
3924 
assembleConstructorDiagonalMatrix(const ConstructorDiagonalMatrix & c)3925 std::string WGSLCodeGenerator::assembleConstructorDiagonalMatrix(
3926         const ConstructorDiagonalMatrix& c) {
3927     const Type& type = c.type();
3928     SkASSERT(type.isMatrix());
3929     SkASSERT(c.argument()->type().isScalar());
3930 
3931     // Evaluate the inner-expression, creating a scratch variable if necessary.
3932     std::string inner = this->writeNontrivialScratchLet(*c.argument(), Precedence::kAssignment);
3933 
3934     // Assemble a diagonal-matrix expression.
3935     std::string expr = to_wgsl_type(fContext, type) + '(';
3936     auto separator = String::Separator();
3937     for (int col = 0; col < type.columns(); ++col) {
3938         for (int row = 0; row < type.rows(); ++row) {
3939             expr += separator();
3940             if (col == row) {
3941                 expr += inner;
3942             } else {
3943                 expr += "0.0";
3944             }
3945         }
3946     }
3947     return expr + ')';
3948 }
3949 
assembleConstructorMatrixResize(const ConstructorMatrixResize & ctor)3950 std::string WGSLCodeGenerator::assembleConstructorMatrixResize(
3951         const ConstructorMatrixResize& ctor) {
3952     std::string source = this->writeNontrivialScratchLet(*ctor.argument(), Precedence::kSequence);
3953     int columns = ctor.type().columns();
3954     int rows = ctor.type().rows();
3955     int sourceColumns = ctor.argument()->type().columns();
3956     int sourceRows = ctor.argument()->type().rows();
3957     auto separator = String::Separator();
3958     std::string expr = to_wgsl_type(fContext, ctor.type()) + '(';
3959 
3960     for (int c = 0; c < columns; ++c) {
3961         for (int r = 0; r < rows; ++r) {
3962             expr += separator();
3963             if (c < sourceColumns && r < sourceRows) {
3964                 String::appendf(&expr, "%s[%d][%d]", source.c_str(), c, r);
3965             } else if (r == c) {
3966                 expr += "1.0";
3967             } else {
3968                 expr += "0.0";
3969             }
3970         }
3971     }
3972 
3973     return expr + ')';
3974 }
3975 
assembleEqualityExpression(const Type & left,const std::string & leftName,const Type & right,const std::string & rightName,Operator op,Precedence parentPrecedence)3976 std::string WGSLCodeGenerator::assembleEqualityExpression(const Type& left,
3977                                                           const std::string& leftName,
3978                                                           const Type& right,
3979                                                           const std::string& rightName,
3980                                                           Operator op,
3981                                                           Precedence parentPrecedence) {
3982     SkASSERT(op.kind() == OperatorKind::EQEQ || op.kind() == OperatorKind::NEQ);
3983 
3984     std::string expr;
3985     bool isEqual = (op.kind() == Operator::Kind::EQEQ);
3986     const char* const combiner = isEqual ? " && " : " || ";
3987 
3988     if (left.isMatrix()) {
3989         // Each matrix column must be compared as if it were an individual vector.
3990         SkASSERT(right.isMatrix());
3991         SkASSERT(left.rows() == right.rows());
3992         SkASSERT(left.columns() == right.columns());
3993         int columns = left.columns();
3994         const Type& vecType = left.columnType(fContext);
3995         const char* separator = "(";
3996         for (int index = 0; index < columns; ++index) {
3997             expr += separator;
3998             std::string suffix = '[' + std::to_string(index) + ']';
3999             expr += this->assembleEqualityExpression(vecType, leftName + suffix,
4000                                                      vecType, rightName + suffix,
4001                                                      op, Precedence::kParentheses);
4002             separator = combiner;
4003         }
4004         return expr + ')';
4005     }
4006 
4007     if (left.isArray()) {
4008         SkASSERT(right.matches(left));
4009         const Type& indexedType = left.componentType();
4010         const char* separator = "(";
4011         for (int index = 0; index < left.columns(); ++index) {
4012             expr += separator;
4013             std::string suffix = '[' + std::to_string(index) + ']';
4014             expr += this->assembleEqualityExpression(indexedType, leftName + suffix,
4015                                                      indexedType, rightName + suffix,
4016                                                      op, Precedence::kParentheses);
4017             separator = combiner;
4018         }
4019         return expr + ')';
4020     }
4021 
4022     if (left.isStruct()) {
4023         // Recursively compare every field in the struct.
4024         SkASSERT(right.matches(left));
4025         SkSpan<const Field> fields = left.fields();
4026 
4027         const char* separator = "(";
4028         for (const Field& field : fields) {
4029             expr += separator;
4030             expr += this->assembleEqualityExpression(
4031                             *field.fType, leftName + '.' + this->assembleName(field.fName),
4032                             *field.fType, rightName + '.' + this->assembleName(field.fName),
4033                             op, Precedence::kParentheses);
4034             separator = combiner;
4035         }
4036         return expr + ')';
4037     }
4038 
4039     if (left.isVector()) {
4040         // Compare vectors via `all(x == y)` or `any(x != y)`.
4041         SkASSERT(right.isVector());
4042         SkASSERT(left.slotCount() == right.slotCount());
4043 
4044         expr += isEqual ? "all(" : "any(";
4045         expr += leftName;
4046         expr += operator_name(op);
4047         expr += rightName;
4048         return expr + ')';
4049     }
4050 
4051     // Compare scalars via `x == y`.
4052     SkASSERT(right.isScalar());
4053     if (parentPrecedence < Precedence::kSequence) {
4054         expr = '(';
4055     }
4056     expr += leftName;
4057     expr += operator_name(op);
4058     expr += rightName;
4059     if (parentPrecedence < Precedence::kSequence) {
4060         expr += ')';
4061     }
4062     return expr;
4063 }
4064 
assembleEqualityExpression(const Expression & left,const Expression & right,Operator op,Precedence parentPrecedence)4065 std::string WGSLCodeGenerator::assembleEqualityExpression(const Expression& left,
4066                                                           const Expression& right,
4067                                                           Operator op,
4068                                                           Precedence parentPrecedence) {
4069     std::string leftName, rightName;
4070     if (left.type().isScalar() || left.type().isVector()) {
4071         // WGSL supports scalar and vector comparisons natively. We know the expressions will only
4072         // be emitted once, so there isn't a benefit to creating a let-declaration.
4073         leftName = this->assembleExpression(left, Precedence::kParentheses);
4074         rightName = this->assembleExpression(right, Precedence::kParentheses);
4075     } else {
4076         leftName = this->writeNontrivialScratchLet(left, Precedence::kAssignment);
4077         rightName = this->writeNontrivialScratchLet(right, Precedence::kAssignment);
4078     }
4079     return this->assembleEqualityExpression(left.type(), leftName, right.type(), rightName,
4080                                             op, parentPrecedence);
4081 }
4082 
writeProgramElement(const ProgramElement & e)4083 void WGSLCodeGenerator::writeProgramElement(const ProgramElement& e) {
4084     switch (e.kind()) {
4085         case ProgramElement::Kind::kExtension:
4086             // TODO(skia:13092): WGSL supports extensions via the "enable" directive
4087             // (https://www.w3.org/TR/WGSL/#enable-extensions-sec ). While we could easily emit this
4088             // directive, we should first ensure that all possible SkSL extension names are
4089             // converted to their appropriate WGSL extension.
4090             break;
4091         case ProgramElement::Kind::kGlobalVar:
4092             this->writeGlobalVarDeclaration(e.as<GlobalVarDeclaration>());
4093             break;
4094         case ProgramElement::Kind::kInterfaceBlock:
4095             // All interface block declarations are handled explicitly as the "program header" in
4096             // generateCode().
4097             break;
4098         case ProgramElement::Kind::kStructDefinition:
4099             this->writeStructDefinition(e.as<StructDefinition>());
4100             break;
4101         case ProgramElement::Kind::kFunctionPrototype:
4102             // A WGSL function declaration must contain its body and the function name is in scope
4103             // for the entire program (see https://www.w3.org/TR/WGSL/#function-declaration and
4104             // https://www.w3.org/TR/WGSL/#declaration-and-scope).
4105             //
4106             // As such, we don't emit function prototypes.
4107             break;
4108         case ProgramElement::Kind::kFunction:
4109             this->writeFunction(e.as<FunctionDefinition>());
4110             break;
4111         case ProgramElement::Kind::kModifiers:
4112             this->writeModifiersDeclaration(e.as<ModifiersDeclaration>());
4113             break;
4114         default:
4115             SkDEBUGFAILF("unsupported program element: %s\n", e.description().c_str());
4116             break;
4117     }
4118 }
4119 
writeTextureOrSampler(const Variable & var,int bindingLocation,std::string_view suffix,std::string_view wgslType)4120 void WGSLCodeGenerator::writeTextureOrSampler(const Variable& var,
4121                                               int bindingLocation,
4122                                               std::string_view suffix,
4123                                               std::string_view wgslType) {
4124     if (var.type().dimensions() != SpvDim2D) {
4125         // Skia currently only uses 2D textures.
4126         fContext.fErrors->error(var.varDeclaration()->position(), "unsupported texture dimensions");
4127         return;
4128     }
4129 
4130     this->write("@group(");
4131     this->write(std::to_string(std::max(0, var.layout().fSet)));
4132     this->write(") @binding(");
4133     this->write(std::to_string(bindingLocation));
4134     this->write(") var ");
4135     this->write(this->assembleName(var.mangledName()));
4136     this->write(suffix);
4137     this->write(": ");
4138     this->write(wgslType);
4139     this->writeLine(";");
4140 }
4141 
writeGlobalVarDeclaration(const GlobalVarDeclaration & d)4142 void WGSLCodeGenerator::writeGlobalVarDeclaration(const GlobalVarDeclaration& d) {
4143     const VarDeclaration& decl = d.varDeclaration();
4144     const Variable& var = *decl.var();
4145     if ((var.modifierFlags() & (ModifierFlag::kIn | ModifierFlag::kOut)) ||
4146         is_in_global_uniforms(var)) {
4147         // Pipeline stage I/O parameters and top-level (non-block) uniforms are handled specially
4148         // in generateCode().
4149         return;
4150     }
4151 
4152     const Type::TypeKind varKind = var.type().typeKind();
4153     if (varKind == Type::TypeKind::kSampler) {
4154         // If the sampler binding was unassigned, provide a scratch value; this will make
4155         // golden-output tests pass, but will not actually be usable for drawing.
4156         int samplerLocation = var.layout().fSampler >= 0 ? var.layout().fSampler
4157                                                          : 10000 + fScratchCount++;
4158         this->writeTextureOrSampler(var, samplerLocation, kSamplerSuffix, "sampler");
4159 
4160         // If the texture binding was unassigned, provide a scratch value (for golden-output tests).
4161         int textureLocation = var.layout().fTexture >= 0 ? var.layout().fTexture
4162                                                          : 10000 + fScratchCount++;
4163         this->writeTextureOrSampler(var, textureLocation, kTextureSuffix, "texture_2d<f32>");
4164         return;
4165     }
4166 
4167     if (varKind == Type::TypeKind::kTexture) {
4168         // If a binding location was unassigned, provide a scratch value (for golden-output tests).
4169         int textureLocation = var.layout().fBinding >= 0 ? var.layout().fBinding
4170                                                          : 10000 + fScratchCount++;
4171         // For a texture without an associated sampler, we don't apply a suffix.
4172         this->writeTextureOrSampler(var, textureLocation, /*suffix=*/"",
4173                                     to_wgsl_type(fContext, var.type(), &var.layout()));
4174         return;
4175     }
4176 
4177     std::string initializer;
4178     if (decl.value()) {
4179         // We assume here that the initial-value expression will not emit any helper statements.
4180         // Initial-value expressions are required to pass IsConstantExpression, which limits the
4181         // blast radius to constructors, literals, and other constant values/variables.
4182         initializer += " = ";
4183         initializer += this->assembleExpression(*decl.value(), Precedence::kAssignment);
4184     }
4185 
4186     if (var.modifierFlags().isConst()) {
4187         this->write("const ");
4188     } else if (var.modifierFlags().isWorkgroup()) {
4189         this->write("var<workgroup> ");
4190     } else if (var.modifierFlags().isPixelLocal()) {
4191         this->write("var<pixel_local> ");
4192     } else {
4193         this->write("var<private> ");
4194     }
4195     this->write(this->assembleName(var.mangledName()));
4196     this->write(": " + to_wgsl_type(fContext, var.type(), &var.layout()));
4197     this->write(initializer);
4198     this->writeLine(";");
4199 }
4200 
writeStructDefinition(const StructDefinition & s)4201 void WGSLCodeGenerator::writeStructDefinition(const StructDefinition& s) {
4202     const Type& type = s.type();
4203     this->writeLine("struct " + type.displayName() + " {");
4204     this->writeFields(type.fields(), /*memoryLayout=*/nullptr);
4205     this->writeLine("};");
4206 }
4207 
writeModifiersDeclaration(const ModifiersDeclaration & modifiers)4208 void WGSLCodeGenerator::writeModifiersDeclaration(const ModifiersDeclaration& modifiers) {
4209     LayoutFlags flags = modifiers.layout().fFlags;
4210     flags &= ~(LayoutFlag::kLocalSizeX | LayoutFlag::kLocalSizeY | LayoutFlag::kLocalSizeZ);
4211     if (flags != LayoutFlag::kNone) {
4212         fContext.fErrors->error(modifiers.position(), "unsupported declaration");
4213         return;
4214     }
4215 
4216     if (modifiers.layout().fLocalSizeX >= 0) {
4217         fLocalSizeX = modifiers.layout().fLocalSizeX;
4218     }
4219     if (modifiers.layout().fLocalSizeY >= 0) {
4220         fLocalSizeY = modifiers.layout().fLocalSizeY;
4221     }
4222     if (modifiers.layout().fLocalSizeZ >= 0) {
4223         fLocalSizeZ = modifiers.layout().fLocalSizeZ;
4224     }
4225 }
4226 
writeFields(SkSpan<const Field> fields,const MemoryLayout * memoryLayout)4227 void WGSLCodeGenerator::writeFields(SkSpan<const Field> fields, const MemoryLayout* memoryLayout) {
4228     fIndentation++;
4229 
4230     // TODO(skia:14370): array uniforms may need manual fixup for std140 padding. (Those uniforms
4231     // will also need special handling when they are accessed, or passed to functions.)
4232     for (size_t index = 0; index < fields.size(); ++index) {
4233         const Field& field = fields[index];
4234         if (memoryLayout && !memoryLayout->isSupported(*field.fType)) {
4235             // Reject types that aren't supported by the memory layout.
4236             fContext.fErrors->error(field.fPosition, "type '" + std::string(field.fType->name()) +
4237                                                      "' is not permitted here");
4238             return;
4239         }
4240 
4241         // Prepend @size(n) to enforce the offsets from the SkSL layout. (This is effectively
4242         // a gadget that we can use to insert padding between elements.)
4243         if (index < fields.size() - 1) {
4244             int thisFieldOffset = field.fLayout.fOffset;
4245             int nextFieldOffset = fields[index + 1].fLayout.fOffset;
4246             if (index == 0 && thisFieldOffset > 0) {
4247                 fContext.fErrors->error(field.fPosition, "field must have an offset of zero");
4248                 return;
4249             }
4250             if (thisFieldOffset >= 0 && nextFieldOffset > thisFieldOffset) {
4251                 this->write("@size(");
4252                 this->write(std::to_string(nextFieldOffset - thisFieldOffset));
4253                 this->write(") ");
4254             }
4255         }
4256 
4257         this->write(this->assembleName(field.fName));
4258         this->write(": ");
4259         if (const FieldPolyfillInfo* info = fFieldPolyfillMap.find(&field)) {
4260             if (info->fIsArray) {
4261                 // This properly handles arrays of matrices, as well as arrays of other primitives.
4262                 SkASSERT(field.fType->isArray());
4263                 this->write("array<_skArrayElement_");
4264                 this->write(field.fType->abbreviatedName());
4265                 this->write(", ");
4266                 this->write(std::to_string(field.fType->columns()));
4267                 this->write(">");
4268             } else if (info->fIsMatrix) {
4269                 this->write("_skMatrix");
4270                 this->write(std::to_string(field.fType->columns()));
4271                 this->write(std::to_string(field.fType->rows()));
4272             } else {
4273                 SkDEBUGFAILF("need polyfill for %s", info->fReplacementName.c_str());
4274             }
4275         } else {
4276             this->write(to_wgsl_type(fContext, *field.fType, &field.fLayout));
4277         }
4278         this->writeLine(",");
4279     }
4280 
4281     fIndentation--;
4282 }
4283 
writeEnables()4284 void WGSLCodeGenerator::writeEnables() {
4285     this->writeLine("diagnostic(off, derivative_uniformity);");
4286     this->writeLine("diagnostic(off, chromium.unreachable_code);");
4287 
4288     if (fRequirements.fPixelLocalExtension) {
4289         this->writeLine("enable chromium_experimental_pixel_local;");
4290     }
4291     if (fProgram.fInterface.fUseLastFragColor) {
4292         this->writeLine("enable chromium_experimental_framebuffer_fetch;");
4293     }
4294     if (fProgram.fInterface.fOutputSecondaryColor) {
4295         this->writeLine("enable dual_source_blending;");
4296     }
4297 }
4298 
needsStageInputStruct() const4299 bool WGSLCodeGenerator::needsStageInputStruct() const {
4300     // It is illegal to declare a struct with no members; we can't emit a placeholder empty stage
4301     // input struct.
4302     return !fPipelineInputs.empty();
4303 }
4304 
writeStageInputStruct()4305 void WGSLCodeGenerator::writeStageInputStruct() {
4306     if (!this->needsStageInputStruct()) {
4307         return;
4308     }
4309 
4310     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4311     SkASSERT(!structNamePrefix.empty());
4312 
4313     this->write("struct ");
4314     this->write(structNamePrefix);
4315     this->writeLine("In {");
4316     fIndentation++;
4317 
4318     for (const Variable* v : fPipelineInputs) {
4319         if (v->type().isInterfaceBlock()) {
4320             for (const Field& f : v->type().fields()) {
4321                 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fModifierFlags, f.fName,
4322                                                  Delimiter::kComma);
4323             }
4324         } else {
4325             this->writePipelineIODeclaration(v->layout(), v->type(), v->modifierFlags(),
4326                                              v->mangledName(), Delimiter::kComma);
4327         }
4328     }
4329 
4330     fIndentation--;
4331     this->writeLine("};");
4332 }
4333 
needsStageOutputStruct() const4334 bool WGSLCodeGenerator::needsStageOutputStruct() const {
4335     // It is illegal to declare a struct with no members. However, vertex programs will _always_
4336     // have an output stage in WGSL, because the spec requires them to emit `@builtin(position)`.
4337     // So we always synthesize a reference to `sk_Position` even if the program doesn't need it.
4338     return !fPipelineOutputs.empty() || ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4339 }
4340 
writeStageOutputStruct()4341 void WGSLCodeGenerator::writeStageOutputStruct() {
4342     if (!this->needsStageOutputStruct()) {
4343         return;
4344     }
4345 
4346     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4347     SkASSERT(!structNamePrefix.empty());
4348 
4349     this->write("struct ");
4350     this->write(structNamePrefix);
4351     this->writeLine("Out {");
4352     fIndentation++;
4353 
4354     bool declaredPositionBuiltin = false;
4355     bool requiresPointSizeBuiltin = false;
4356     for (const Variable* v : fPipelineOutputs) {
4357         if (v->type().isInterfaceBlock()) {
4358             for (const auto& f : v->type().fields()) {
4359                 this->writePipelineIODeclaration(f.fLayout, *f.fType, f.fModifierFlags, f.fName,
4360                                                  Delimiter::kComma);
4361                 if (f.fLayout.fBuiltin == SK_POSITION_BUILTIN) {
4362                     declaredPositionBuiltin = true;
4363                 } else if (f.fLayout.fBuiltin == SK_POINTSIZE_BUILTIN) {
4364                     // sk_PointSize is explicitly not supported by `builtin_from_sksl_name` so
4365                     // writePipelineIODeclaration will never write it. We mark it here if the
4366                     // declaration is needed so we can synthesize it below.
4367                     requiresPointSizeBuiltin = true;
4368                 }
4369             }
4370         } else {
4371             this->writePipelineIODeclaration(v->layout(), v->type(), v->modifierFlags(),
4372                                              v->mangledName(), Delimiter::kComma);
4373         }
4374     }
4375 
4376     // A vertex program must include the `position` builtin in its entrypoint's return type.
4377     const bool positionBuiltinRequired = ProgramConfig::IsVertex(fProgram.fConfig->fKind);
4378     if (positionBuiltinRequired && !declaredPositionBuiltin) {
4379         this->writeLine("@builtin(position) sk_Position: vec4<f32>,");
4380     }
4381 
4382     fIndentation--;
4383     this->writeLine("};");
4384 
4385     // In WebGPU/WGSL, the vertex stage does not support a point-size output and the size
4386     // of a point primitive is always 1 pixel (see https://github.com/gpuweb/gpuweb/issues/332).
4387     //
4388     // There isn't anything we can do to emulate this correctly at this stage so we synthesize a
4389     // placeholder global variable that has no effect. Programs should not rely on sk_PointSize when
4390     // using the Dawn backend.
4391     if (ProgramConfig::IsVertex(fProgram.fConfig->fKind) && requiresPointSizeBuiltin) {
4392         this->writeLine("/* unsupported */ var<private> sk_PointSize: f32;");
4393     }
4394 }
4395 
prepareUniformPolyfillsForInterfaceBlock(const InterfaceBlock * interfaceBlock,std::string_view instanceName,MemoryLayout::Standard nativeLayout)4396 void WGSLCodeGenerator::prepareUniformPolyfillsForInterfaceBlock(
4397         const InterfaceBlock* interfaceBlock,
4398         std::string_view instanceName,
4399         MemoryLayout::Standard nativeLayout) {
4400     SkSL::MemoryLayout std140(MemoryLayout::Standard::k140);
4401     SkSL::MemoryLayout native(nativeLayout);
4402 
4403     const Type& structType = interfaceBlock->var()->type().componentType();
4404     for (const Field& field : structType.fields()) {
4405         const Type* type = field.fType;
4406         bool needsArrayPolyfill = false;
4407         bool needsMatrixPolyfill = false;
4408 
4409         auto isPolyfillableMatrixType = [&](const Type* type) {
4410             return type->isMatrix() && std140.stride(*type) != native.stride(*type);
4411         };
4412 
4413         if (isPolyfillableMatrixType(type)) {
4414             // Matrices will be represented as 16-byte aligned arrays in std140, and reconstituted
4415             // into proper matrices as they are later accessed. We need to synthesize polyfill.
4416             needsMatrixPolyfill = true;
4417         } else if (type->isArray() && !type->isUnsizedArray() &&
4418                    !type->componentType().isOpaque()) {
4419             const Type* innerType = &type->componentType();
4420             if (isPolyfillableMatrixType(innerType)) {
4421                 // Use a polyfill when the array contains a matrix that requires polyfill.
4422                 needsArrayPolyfill = true;
4423                 needsMatrixPolyfill = true;
4424             } else if (native.size(*innerType) < 16) {
4425                 // Use a polyfill when the array elements are smaller than 16 bytes, since std140
4426                 // will pad elements to a 16-byte stride.
4427                 needsArrayPolyfill = true;
4428             }
4429         }
4430 
4431         if (needsArrayPolyfill || needsMatrixPolyfill) {
4432             // Add a polyfill for this matrix type.
4433             FieldPolyfillInfo info;
4434             info.fInterfaceBlock = interfaceBlock;
4435             info.fReplacementName = "_skUnpacked_" + std::string(instanceName) + '_' +
4436                                     this->assembleName(field.fName);
4437             info.fIsArray = needsArrayPolyfill;
4438             info.fIsMatrix = needsMatrixPolyfill;
4439             fFieldPolyfillMap.set(&field, info);
4440         }
4441     }
4442 }
4443 
writeUniformsAndBuffers()4444 void WGSLCodeGenerator::writeUniformsAndBuffers() {
4445     for (const ProgramElement* e : fProgram.elements()) {
4446         // Iterate through the interface blocks.
4447         if (!e->is<InterfaceBlock>()) {
4448             continue;
4449         }
4450         const InterfaceBlock& ib = e->as<InterfaceBlock>();
4451 
4452         // Determine if this interface block holds uniforms, buffers, or something else (skip it).
4453         std::string_view addressSpace;
4454         std::string_view accessMode;
4455         MemoryLayout::Standard nativeLayout;
4456         if (ib.var()->modifierFlags().isUniform()) {
4457             addressSpace = "uniform";
4458             nativeLayout = MemoryLayout::Standard::kWGSLUniform_Base;
4459         } else if (ib.var()->modifierFlags().isBuffer()) {
4460             addressSpace = "storage";
4461             nativeLayout = MemoryLayout::Standard::kWGSLStorage_Base;
4462             accessMode = ib.var()->modifierFlags().isReadOnly() ? ", read" : ", read_write";
4463         } else {
4464             continue;
4465         }
4466 
4467         // If we have an anonymous interface block, assign a name like `_uniform0` or `_storage1`.
4468         std::string instanceName;
4469         if (ib.instanceName().empty()) {
4470             instanceName = "_" + std::string(addressSpace) + std::to_string(fScratchCount++);
4471             fInterfaceBlockNameMap[&ib.var()->type().componentType()] = instanceName;
4472         } else {
4473             instanceName = std::string(ib.instanceName());
4474         }
4475 
4476         this->prepareUniformPolyfillsForInterfaceBlock(&ib, instanceName, nativeLayout);
4477 
4478         // Create a struct to hold all of the fields from this InterfaceBlock.
4479         SkASSERT(!ib.typeName().empty());
4480         this->write("struct ");
4481         this->write(ib.typeName());
4482         this->writeLine(" {");
4483 
4484         // Find the struct type and fields used by this interface block.
4485         const Type& ibType = ib.var()->type().componentType();
4486         SkASSERT(ibType.isStruct());
4487 
4488         SkSpan<const Field> ibFields = ibType.fields();
4489         SkASSERT(!ibFields.empty());
4490 
4491         MemoryLayout layout(MemoryLayout::Standard::k140);
4492         this->writeFields(ibFields, &layout);
4493         this->writeLine("};");
4494         this->write("@group(");
4495         this->write(std::to_string(std::max(0, ib.var()->layout().fSet)));
4496         this->write(") @binding(");
4497         this->write(std::to_string(std::max(0, ib.var()->layout().fBinding)));
4498         this->write(") var<");
4499         this->write(addressSpace);
4500         this->write(accessMode);
4501         this->write("> ");
4502         this->write(instanceName);
4503         this->write(" : ");
4504         this->write(to_wgsl_type(fContext, ib.var()->type(), &ib.var()->layout()));
4505         this->writeLine(";");
4506     }
4507 }
4508 
writeNonBlockUniformsForTests()4509 void WGSLCodeGenerator::writeNonBlockUniformsForTests() {
4510     bool declaredUniformsStruct = false;
4511 
4512     for (const ProgramElement* e : fProgram.elements()) {
4513         if (e->is<GlobalVarDeclaration>()) {
4514             const GlobalVarDeclaration& decls = e->as<GlobalVarDeclaration>();
4515             const Variable& var = *decls.varDeclaration().var();
4516             if (is_in_global_uniforms(var)) {
4517                 if (!declaredUniformsStruct) {
4518                     this->write("struct _GlobalUniforms {\n");
4519                     declaredUniformsStruct = true;
4520                 }
4521                 this->write("  ");
4522                 this->writeVariableDecl(var.layout(), var.type(), var.mangledName(),
4523                                         Delimiter::kComma);
4524             }
4525         }
4526     }
4527     if (declaredUniformsStruct) {
4528         int binding = fProgram.fConfig->fSettings.fDefaultUniformBinding;
4529         int set = fProgram.fConfig->fSettings.fDefaultUniformSet;
4530         this->write("};\n");
4531         this->write("@binding(" + std::to_string(binding) + ") ");
4532         this->write("@group(" + std::to_string(set) + ") ");
4533         this->writeLine("var<uniform> _globalUniforms: _GlobalUniforms;");
4534     }
4535 }
4536 
functionDependencyArgs(const FunctionDeclaration & f)4537 std::string WGSLCodeGenerator::functionDependencyArgs(const FunctionDeclaration& f) {
4538     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4539     std::string args;
4540     if (deps && *deps) {
4541         const char* separator = "";
4542         if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4543             args += "_stageIn";
4544             separator = ", ";
4545         }
4546         if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4547             args += separator;
4548             args += "_stageOut";
4549         }
4550     }
4551     return args;
4552 }
4553 
writeFunctionDependencyParams(const FunctionDeclaration & f)4554 bool WGSLCodeGenerator::writeFunctionDependencyParams(const FunctionDeclaration& f) {
4555     WGSLFunctionDependencies* deps = fRequirements.fDependencies.find(&f);
4556     if (!deps || !*deps) {
4557         return false;
4558     }
4559 
4560     std::string_view structNamePrefix = pipeline_struct_prefix(fProgram.fConfig->fKind);
4561     if (structNamePrefix.empty()) {
4562         return false;
4563     }
4564     const char* separator = "";
4565     if (*deps & WGSLFunctionDependency::kPipelineInputs) {
4566         this->write("_stageIn: ");
4567         separator = ", ";
4568         this->write(structNamePrefix);
4569         this->write("In");
4570     }
4571     if (*deps & WGSLFunctionDependency::kPipelineOutputs) {
4572         this->write(separator);
4573         this->write("_stageOut: ptr<function, ");
4574         this->write(structNamePrefix);
4575         this->write("Out>");
4576     }
4577     return true;
4578 }
4579 
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out,PrettyPrint pp,IncludeSyntheticCode isc,ValidateWGSLProc validateWGSL)4580 bool ToWGSL(Program& program,
4581             const ShaderCaps* caps,
4582             OutputStream& out,
4583             PrettyPrint pp,
4584             IncludeSyntheticCode isc,
4585             ValidateWGSLProc validateWGSL) {
4586     TRACE_EVENT0("skia.shaders", "SkSL::ToWGSL");
4587     SkASSERT(caps != nullptr);
4588 
4589     program.fContext->fErrors->setSource(*program.fSource);
4590     bool result;
4591     if (validateWGSL) {
4592         StringStream wgsl;
4593         WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &wgsl, pp, isc);
4594         result = cg.generateCode();
4595         if (result) {
4596             std::string_view wgslBytes = wgsl.str();
4597             std::string warnings;
4598             result = validateWGSL(*program.fContext->fErrors, wgslBytes, &warnings);
4599             if (!warnings.empty()) {
4600                 out.writeText("/* Tint reported warnings. */\n\n");
4601             }
4602             out.write(wgslBytes.data(), wgslBytes.size());
4603         }
4604     } else {
4605         WGSLCodeGenerator cg(program.fContext.get(), caps, &program, &out, pp, isc);
4606         result = cg.generateCode();
4607     }
4608     program.fContext->fErrors->setSource(std::string_view());
4609 
4610     return result;
4611 }
4612 
ToWGSL(Program & program,const ShaderCaps * caps,OutputStream & out)4613 bool ToWGSL(Program& program, const ShaderCaps* caps, OutputStream& out) {
4614 #if defined(SK_DEBUG)
4615     constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kYes;
4616 #else
4617     constexpr PrettyPrint defaultPrintOpts = PrettyPrint::kNo;
4618 #endif
4619     return ToWGSL(program, caps, out, defaultPrintOpts, IncludeSyntheticCode::kNo, nullptr);
4620 }
4621 
ToWGSL(Program & program,const ShaderCaps * caps,std::string * out)4622 bool ToWGSL(Program& program, const ShaderCaps* caps, std::string* out) {
4623     StringStream buffer;
4624     if (!ToWGSL(program, caps, buffer)) {
4625         return false;
4626     }
4627     *out = buffer.str();
4628     return true;
4629 }
4630 
4631 }  // namespace SkSL
4632