xref: /aosp_15_r20/external/skia/src/sksl/ir/SkSLSwizzle.cpp (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwizzle.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/private/base/SkTArray.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLConstantFolder.h"
14 #include "src/sksl/SkSLContext.h"
15 #include "src/sksl/SkSLDefines.h"
16 #include "src/sksl/SkSLErrorReporter.h"
17 #include "src/sksl/SkSLOperator.h"
18 #include "src/sksl/SkSLString.h"
19 #include "src/sksl/ir/SkSLConstructorCompound.h"
20 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
21 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
22 #include "src/sksl/ir/SkSLConstructorSplat.h"
23 #include "src/sksl/ir/SkSLLiteral.h"
24 
25 #include <algorithm>
26 #include <cstdint>
27 #include <cstring>
28 #include <optional>
29 
30 using namespace skia_private;
31 
32 namespace SkSL {
33 
validate_swizzle_domain(const ComponentArray & fields)34 static bool validate_swizzle_domain(const ComponentArray& fields) {
35     enum SwizzleDomain {
36         kCoordinate,
37         kColor,
38         kUV,
39         kRectangle,
40     };
41 
42     std::optional<SwizzleDomain> domain;
43 
44     for (int8_t field : fields) {
45         SwizzleDomain fieldDomain;
46         switch (field) {
47             case SwizzleComponent::X:
48             case SwizzleComponent::Y:
49             case SwizzleComponent::Z:
50             case SwizzleComponent::W:
51                 fieldDomain = kCoordinate;
52                 break;
53             case SwizzleComponent::R:
54             case SwizzleComponent::G:
55             case SwizzleComponent::B:
56             case SwizzleComponent::A:
57                 fieldDomain = kColor;
58                 break;
59             case SwizzleComponent::S:
60             case SwizzleComponent::T:
61             case SwizzleComponent::P:
62             case SwizzleComponent::Q:
63                 fieldDomain = kUV;
64                 break;
65             case SwizzleComponent::UL:
66             case SwizzleComponent::UT:
67             case SwizzleComponent::UR:
68             case SwizzleComponent::UB:
69                 fieldDomain = kRectangle;
70                 break;
71             case SwizzleComponent::ZERO:
72             case SwizzleComponent::ONE:
73                 continue;
74             default:
75                 return false;
76         }
77 
78         if (!domain.has_value()) {
79             domain = fieldDomain;
80         } else if (domain != fieldDomain) {
81             return false;
82         }
83     }
84 
85     return true;
86 }
87 
mask_char(int8_t component)88 static char mask_char(int8_t component) {
89     switch (component) {
90         case SwizzleComponent::X:    return 'x';
91         case SwizzleComponent::Y:    return 'y';
92         case SwizzleComponent::Z:    return 'z';
93         case SwizzleComponent::W:    return 'w';
94         case SwizzleComponent::R:    return 'r';
95         case SwizzleComponent::G:    return 'g';
96         case SwizzleComponent::B:    return 'b';
97         case SwizzleComponent::A:    return 'a';
98         case SwizzleComponent::S:    return 's';
99         case SwizzleComponent::T:    return 't';
100         case SwizzleComponent::P:    return 'p';
101         case SwizzleComponent::Q:    return 'q';
102         case SwizzleComponent::UL:   return 'L';
103         case SwizzleComponent::UT:   return 'T';
104         case SwizzleComponent::UR:   return 'R';
105         case SwizzleComponent::UB:   return 'B';
106         case SwizzleComponent::ZERO: return '0';
107         case SwizzleComponent::ONE:  return '1';
108         default: SkUNREACHABLE;
109     }
110 }
111 
MaskString(const ComponentArray & components)112 std::string Swizzle::MaskString(const ComponentArray& components) {
113     std::string result;
114     for (int8_t component : components) {
115         result += mask_char(component);
116     }
117     return result;
118 }
119 
optimize_constructor_swizzle(const Context & context,Position pos,const ConstructorCompound & base,ComponentArray components)120 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
121                                                                 Position pos,
122                                                                 const ConstructorCompound& base,
123                                                                 ComponentArray components) {
124     auto baseArguments = base.argumentSpan();
125     std::unique_ptr<Expression> replacement;
126     const Type& exprType = base.type();
127     const Type& componentType = exprType.componentType();
128     int swizzleSize = components.size();
129 
130     // Swizzles can duplicate some elements and discard others, e.g.
131     // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
132     // - Expressions with side effects need to occur exactly once, even if they would otherwise be
133     //   swizzle-eliminated
134     // - Non-trivial expressions should not be repeated, but elimination is OK.
135     //
136     // Look up the argument for the constructor at each index. This is typically simple but for
137     // weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it seems. This example
138     // would result in:
139     //     argMap[0] = {.fArgIndex = 0, .fComponent = 0}   (bar.yz     .x)
140     //     argMap[1] = {.fArgIndex = 0, .fComponent = 1}   (bar.yz     .y)
141     //     argMap[2] = {.fArgIndex = 1, .fComponent = 0}   (half2(foo) .x)
142     //     argMap[3] = {.fArgIndex = 1, .fComponent = 1}   (half2(foo) .y)
143     struct ConstructorArgMap {
144         int8_t fArgIndex;
145         int8_t fComponent;
146     };
147 
148     int numConstructorArgs = base.type().columns();
149     ConstructorArgMap argMap[4] = {};
150     int writeIdx = 0;
151     for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
152         const Expression& arg = *baseArguments[argIdx];
153         const Type& argType = arg.type();
154 
155         if (!argType.isScalar() && !argType.isVector()) {
156             return nullptr;
157         }
158 
159         int argSlots = argType.slotCount();
160         for (int componentIdx = 0; componentIdx < argSlots; ++componentIdx) {
161             argMap[writeIdx].fArgIndex = argIdx;
162             argMap[writeIdx].fComponent = componentIdx;
163             ++writeIdx;
164         }
165     }
166     SkASSERT(writeIdx == numConstructorArgs);
167 
168     // Count up the number of times each constructor argument is used by the swizzle.
169     //    `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
170     // - bar.yz    is referenced 3 times, by `.x_xy`
171     // - half(foo) is referenced 1 time,  by `._w__`
172     int8_t exprUsed[4] = {};
173     for (int8_t c : components) {
174         exprUsed[argMap[c].fArgIndex]++;
175     }
176 
177     for (int index = 0; index < numConstructorArgs; ++index) {
178         int8_t constructorArgIndex = argMap[index].fArgIndex;
179         const Expression& baseArg = *baseArguments[constructorArgIndex];
180 
181         // Check that non-trivial expressions are not swizzled in more than once.
182         if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
183             return nullptr;
184         }
185         // Check that side-effect-bearing expressions are swizzled in exactly once.
186         if (exprUsed[constructorArgIndex] != 1 && Analysis::HasSideEffects(baseArg)) {
187             return nullptr;
188         }
189     }
190 
191     struct ReorderedArgument {
192         int8_t fArgIndex;
193         ComponentArray fComponents;
194     };
195     STArray<4, ReorderedArgument> reorderedArgs;
196     for (int8_t c : components) {
197         const ConstructorArgMap& argument = argMap[c];
198         const Expression& baseArg = *baseArguments[argument.fArgIndex];
199 
200         if (baseArg.type().isScalar()) {
201             // This argument is a scalar; add it to the list as-is.
202             SkASSERT(argument.fComponent == 0);
203             reorderedArgs.push_back({argument.fArgIndex,
204                                      ComponentArray{}});
205         } else {
206             // This argument is a component from a vector.
207             SkASSERT(baseArg.type().isVector());
208             SkASSERT(argument.fComponent < baseArg.type().columns());
209             if (reorderedArgs.empty() ||
210                 reorderedArgs.back().fArgIndex != argument.fArgIndex) {
211                 // This can't be combined with the previous argument. Add a new one.
212                 reorderedArgs.push_back({argument.fArgIndex,
213                                          ComponentArray{argument.fComponent}});
214             } else {
215                 // Since we know this argument uses components, it should already have at least one
216                 // component set.
217                 SkASSERT(!reorderedArgs.back().fComponents.empty());
218                 // Build up the current argument with one more component.
219                 reorderedArgs.back().fComponents.push_back(argument.fComponent);
220             }
221         }
222     }
223 
224     // Convert our reordered argument list to an actual array of expressions, with the new order and
225     // any new inner swizzles that need to be applied.
226     ExpressionArray newArgs;
227     newArgs.reserve_exact(swizzleSize);
228     for (const ReorderedArgument& reorderedArg : reorderedArgs) {
229         std::unique_ptr<Expression> newArg = baseArguments[reorderedArg.fArgIndex]->clone();
230 
231         if (reorderedArg.fComponents.empty()) {
232             newArgs.push_back(std::move(newArg));
233         } else {
234             newArgs.push_back(Swizzle::Make(context, pos, std::move(newArg),
235                                             reorderedArg.fComponents));
236         }
237     }
238 
239     // Wrap the new argument list in a compound constructor.
240     return ConstructorCompound::Make(context,
241                                      pos,
242                                      componentType.toCompound(context, swizzleSize, /*rows=*/1),
243                                      std::move(newArgs));
244 }
245 
Convert(const Context & context,Position pos,Position maskPos,std::unique_ptr<Expression> base,std::string_view componentString)246 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
247                                              Position pos,
248                                              Position maskPos,
249                                              std::unique_ptr<Expression> base,
250                                              std::string_view componentString) {
251     if (componentString.size() > 4) {
252         context.fErrors->error(Position::Range(maskPos.startOffset() + 4,
253                                                maskPos.endOffset()),
254                                "too many components in swizzle mask");
255         return nullptr;
256     }
257 
258     // Convert the component string into an equivalent array.
259     ComponentArray components;
260     for (size_t i = 0; i < componentString.length(); ++i) {
261         char field = componentString[i];
262         switch (field) {
263             case '0': components.push_back(SwizzleComponent::ZERO); break;
264             case '1': components.push_back(SwizzleComponent::ONE);  break;
265             case 'x': components.push_back(SwizzleComponent::X);    break;
266             case 'r': components.push_back(SwizzleComponent::R);    break;
267             case 's': components.push_back(SwizzleComponent::S);    break;
268             case 'L': components.push_back(SwizzleComponent::UL);   break;
269             case 'y': components.push_back(SwizzleComponent::Y);    break;
270             case 'g': components.push_back(SwizzleComponent::G);    break;
271             case 't': components.push_back(SwizzleComponent::T);    break;
272             case 'T': components.push_back(SwizzleComponent::UT);   break;
273             case 'z': components.push_back(SwizzleComponent::Z);    break;
274             case 'b': components.push_back(SwizzleComponent::B);    break;
275             case 'p': components.push_back(SwizzleComponent::P);    break;
276             case 'R': components.push_back(SwizzleComponent::UR);   break;
277             case 'w': components.push_back(SwizzleComponent::W);    break;
278             case 'a': components.push_back(SwizzleComponent::A);    break;
279             case 'q': components.push_back(SwizzleComponent::Q);    break;
280             case 'B': components.push_back(SwizzleComponent::UB);   break;
281             default:
282                 context.fErrors->error(Position::Range(maskPos.startOffset() + i,
283                                                        maskPos.startOffset() + i + 1),
284                                        String::printf("invalid swizzle component '%c'", field));
285                 return nullptr;
286         }
287     }
288 
289     if (!validate_swizzle_domain(components)) {
290         context.fErrors->error(maskPos, "invalid swizzle mask '" + MaskString(components) + "'");
291         return nullptr;
292     }
293 
294     const Type& baseType = base->type().scalarTypeForLiteral();
295 
296     if (!baseType.isVector() && !baseType.isScalar()) {
297         context.fErrors->error(pos, "cannot swizzle value of type '" +
298                                     baseType.displayName() + "'");
299         return nullptr;
300     }
301 
302     ComponentArray maskComponents;
303     bool foundXYZW = false;
304     for (int i = 0; i < components.size(); ++i) {
305         switch (components[i]) {
306             case SwizzleComponent::ZERO:
307             case SwizzleComponent::ONE:
308                 // Skip over constant fields for now.
309                 break;
310             case SwizzleComponent::X:
311             case SwizzleComponent::R:
312             case SwizzleComponent::S:
313             case SwizzleComponent::UL:
314                 foundXYZW = true;
315                 maskComponents.push_back(SwizzleComponent::X);
316                 break;
317             case SwizzleComponent::Y:
318             case SwizzleComponent::G:
319             case SwizzleComponent::T:
320             case SwizzleComponent::UT:
321                 foundXYZW = true;
322                 if (baseType.columns() >= 2) {
323                     maskComponents.push_back(SwizzleComponent::Y);
324                     break;
325                 }
326                 [[fallthrough]];
327             case SwizzleComponent::Z:
328             case SwizzleComponent::B:
329             case SwizzleComponent::P:
330             case SwizzleComponent::UR:
331                 foundXYZW = true;
332                 if (baseType.columns() >= 3) {
333                     maskComponents.push_back(SwizzleComponent::Z);
334                     break;
335                 }
336                 [[fallthrough]];
337             case SwizzleComponent::W:
338             case SwizzleComponent::A:
339             case SwizzleComponent::Q:
340             case SwizzleComponent::UB:
341                 foundXYZW = true;
342                 if (baseType.columns() >= 4) {
343                     maskComponents.push_back(SwizzleComponent::W);
344                     break;
345                 }
346                 [[fallthrough]];
347             default:
348                 // The swizzle component references a field that doesn't exist in the base type.
349                 context.fErrors->error(Position::Range(maskPos.startOffset() + i,
350                                                        maskPos.startOffset() + i + 1),
351                                        String::printf("invalid swizzle component '%c'",
352                                                       mask_char(components[i])));
353                 return nullptr;
354         }
355     }
356 
357     if (!foundXYZW) {
358         context.fErrors->error(maskPos, "swizzle must refer to base expression");
359         return nullptr;
360     }
361 
362     // Coerce literals in expressions such as `(12345).xxx` to their actual type.
363     base = baseType.coerceExpression(std::move(base), context);
364     if (!base) {
365         return nullptr;
366     }
367 
368     // Swizzles are complicated due to constant components. The most difficult case is a mask like
369     // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that
370     // evaluates 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0')
371     // together and use a secondary swizzle to put them back into the right order, so in this case
372     // we end up with 'float4(base.xw, 1, 0).xzyw'.
373     //
374     // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
375     //   scalar.xxx  -> type3(scalar)
376     //   scalar.x0x0 -> type2(scalar)
377     //   vector.zyx  -> vector.zyx
378     //   vector.x0y0 -> vector.xy
379     std::unique_ptr<Expression> expr = Swizzle::Make(context, pos, std::move(base), maskComponents);
380 
381     // If we have processed the entire swizzle, we're done.
382     if (maskComponents.size() == components.size()) {
383         return expr;
384     }
385 
386     // Now we create a constructor that has the correct number of elements for the final swizzle,
387     // with all fields at the start. It's not finished yet; constants we need will be added below.
388     //   scalar.x0x0 -> type4(type2(x), ...)
389     //   vector.y111 -> type4(vector.y, ...)
390     //   vector.z10x -> type4(vector.zx, ...)
391     //
392     // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
393     ExpressionArray constructorArgs;
394     constructorArgs.reserve_exact(3);
395     constructorArgs.push_back(std::move(expr));
396 
397     // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
398     // need are also tacked on to the end of the constructor.
399     //   scalar.x0x0 -> type4(type2(x), 0).xyxy
400     //   vector.y111 -> type2(vector.y, 1).xyyy
401     //   vector.z10x -> type4(vector.zx, 1, 0).xzwy
402     const Type* scalarType = &baseType.componentType();
403     ComponentArray swizzleComponents;
404     int maskFieldIdx = 0;
405     int constantFieldIdx = maskComponents.size();
406     int constantZeroIdx = -1, constantOneIdx = -1;
407 
408     for (int i = 0; i < components.size(); i++) {
409         switch (components[i]) {
410             case SwizzleComponent::ZERO:
411                 if (constantZeroIdx == -1) {
412                     // Synthesize a '0' argument at the end of the constructor.
413                     constructorArgs.push_back(Literal::Make(pos, /*value=*/0, scalarType));
414                     constantZeroIdx = constantFieldIdx++;
415                 }
416                 swizzleComponents.push_back(constantZeroIdx);
417                 break;
418             case SwizzleComponent::ONE:
419                 if (constantOneIdx == -1) {
420                     // Synthesize a '1' argument at the end of the constructor.
421                     constructorArgs.push_back(Literal::Make(pos, /*value=*/1, scalarType));
422                     constantOneIdx = constantFieldIdx++;
423                 }
424                 swizzleComponents.push_back(constantOneIdx);
425                 break;
426             default:
427                 // The non-constant fields are already in the expected order.
428                 swizzleComponents.push_back(maskFieldIdx++);
429                 break;
430         }
431     }
432 
433     expr = ConstructorCompound::Make(context, pos,
434                                      scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
435                                      std::move(constructorArgs));
436 
437     // Create (and potentially optimize-away) the resulting swizzle-expression.
438     return Swizzle::Make(context, pos, std::move(expr), swizzleComponents);
439 }
440 
IsIdentity(const ComponentArray & components)441 bool Swizzle::IsIdentity(const ComponentArray& components) {
442     for (int index = 0; index < components.size(); ++index) {
443         if (components[index] != index) {
444             return false;
445         }
446     }
447     return true;
448 }
449 
Make(const Context & context,Position pos,std::unique_ptr<Expression> expr,ComponentArray components)450 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
451                                           Position pos,
452                                           std::unique_ptr<Expression> expr,
453                                           ComponentArray components) {
454     const Type& exprType = expr->type();
455     SkASSERTF(exprType.isVector() || exprType.isScalar(),
456               "cannot swizzle type '%s'", exprType.description().c_str());
457     SkASSERT(components.size() >= 1 && components.size() <= 4);
458 
459     // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
460     // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
461     // ones in them.)
462     SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
463         return component >= SwizzleComponent::X &&
464                component <= SwizzleComponent::W;
465     }));
466 
467     // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
468     // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
469     if (exprType.isScalar()) {
470         return ConstructorSplat::Make(context, pos,
471                                       exprType.toCompound(context, components.size(), /*rows=*/1),
472                                       std::move(expr));
473     }
474 
475     // Detect identity swizzles like `color.rgba` and optimize them away.
476     if (components.size() == exprType.columns() && IsIdentity(components)) {
477         expr->fPosition = pos;
478         return expr;
479     }
480 
481     // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
482     if (expr->is<Swizzle>()) {
483         Swizzle& base = expr->as<Swizzle>();
484         ComponentArray combined;
485         for (int8_t c : components) {
486             combined.push_back(base.components()[c]);
487         }
488 
489         // It may actually be possible to further simplify this swizzle. Go again.
490         // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
491         return Swizzle::Make(context, pos, std::move(base.base()), combined);
492     }
493 
494     // If we are swizzling a constant expression, we can use its value instead here (so that
495     // swizzles like `colorWhite.x` can be simplified to `1`).
496     const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
497 
498     // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
499     // optimized to just `scalar`. The swizzle components don't actually matter, as every field
500     // in a splat constructor holds the same value.
501     if (value->is<ConstructorSplat>()) {
502         const ConstructorSplat& splat = value->as<ConstructorSplat>();
503         return ConstructorSplat::Make(
504                 context, pos,
505                 splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
506                 splat.argument()->clone());
507     }
508 
509     // Swizzles on casts, like `half4(myFloat4).zyy`, can optimize to `half3(myFloat4.zyy)`.
510     if (value->is<ConstructorCompoundCast>()) {
511         const ConstructorCompoundCast& cast = value->as<ConstructorCompoundCast>();
512         const Type& castType = cast.type().componentType().toCompound(context, components.size(),
513                                                                       /*rows=*/1);
514         std::unique_ptr<Expression> swizzled = Swizzle::Make(context, pos, cast.argument()->clone(),
515                                                              std::move(components));
516         return (castType.columns() > 1)
517                        ? ConstructorCompoundCast::Make(context, pos, castType, std::move(swizzled))
518                        : ConstructorScalarCast::Make(context, pos, castType, std::move(swizzled));
519     }
520 
521     // Swizzles on compound constructors, like `half4(1, 2, 3, 4).yw`, can become `half2(2, 4)`.
522     if (value->is<ConstructorCompound>()) {
523         const ConstructorCompound& ctor = value->as<ConstructorCompound>();
524         if (auto replacement = optimize_constructor_swizzle(context, pos, ctor, components)) {
525             return replacement;
526         }
527     }
528 
529     // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
530     return std::make_unique<Swizzle>(context, pos, std::move(expr), components);
531 }
532 
MakeExact(const Context & context,Position pos,std::unique_ptr<Expression> expr,ComponentArray components)533 std::unique_ptr<Expression> Swizzle::MakeExact(const Context& context,
534                                                Position pos,
535                                                std::unique_ptr<Expression> expr,
536                                                ComponentArray components) {
537     SkASSERTF(expr->type().isVector() || expr->type().isScalar(),
538               "cannot swizzle type '%s'", expr->type().description().c_str());
539     SkASSERT(components.size() >= 1 && components.size() <= 4);
540 
541     // Confirm that the component array only contains X/Y/Z/W.
542     SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
543         return component >= SwizzleComponent::X &&
544                component <= SwizzleComponent::W;
545     }));
546 
547     return std::make_unique<Swizzle>(context, pos, std::move(expr), components);
548 }
549 
description(OperatorPrecedence) const550 std::string Swizzle::description(OperatorPrecedence) const {
551     return this->base()->description(OperatorPrecedence::kPostfix) + "." +
552            MaskString(this->components());
553 }
554 
555 }  // namespace SkSL
556