xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/RewritePixelLocalStorage.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/RewritePixelLocalStorage.h"
8 
9 #include "common/angleutils.h"
10 #include "compiler/translator/StaticType.h"
11 #include "compiler/translator/SymbolTable.h"
12 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
13 #include "compiler/translator/tree_util/BuiltIn.h"
14 #include "compiler/translator/tree_util/FindMain.h"
15 #include "compiler/translator/tree_util/IntermNode_util.h"
16 #include "compiler/translator/tree_util/IntermTraverse.h"
17 
18 namespace sh
19 {
20 namespace
21 {
DataTypeOfPLSType(TBasicType plsType)22 constexpr static TBasicType DataTypeOfPLSType(TBasicType plsType)
23 {
24     switch (plsType)
25     {
26         case EbtPixelLocalANGLE:
27             return EbtFloat;
28         case EbtIPixelLocalANGLE:
29             return EbtInt;
30         case EbtUPixelLocalANGLE:
31             return EbtUInt;
32         default:
33             UNREACHABLE();
34             return EbtVoid;
35     }
36 }
37 
DataTypeOfImageType(TBasicType imageType)38 constexpr static TBasicType DataTypeOfImageType(TBasicType imageType)
39 {
40     switch (imageType)
41     {
42         case EbtImage2D:
43             return EbtFloat;
44         case EbtIImage2D:
45             return EbtInt;
46         case EbtUImage2D:
47             return EbtUInt;
48         default:
49             UNREACHABLE();
50             return EbtVoid;
51     }
52 }
53 
54 // Maps PLS symbols to a backing store.
55 template <typename T>
56 class PLSBackingStoreMap
57 {
58   public:
59     // Sets the given variable as the backing storage for the plsSymbol's binding point. An entry
60     // must not already exist in the map for this binding point.
insertNew(TIntermSymbol * plsSymbol,const T & backingStore)61     void insertNew(TIntermSymbol *plsSymbol, const T &backingStore)
62     {
63         ASSERT(plsSymbol);
64         ASSERT(IsPixelLocal(plsSymbol->getBasicType()));
65         int binding = plsSymbol->getType().getLayoutQualifier().binding;
66         ASSERT(binding >= 0);
67         auto result = mMap.insert({binding, backingStore});
68         ASSERT(result.second);  // Ensure an image didn't already exist for this symbol.
69     }
70 
71     // Looks up the backing store for the given plsSymbol's binding point. An entry must already
72     // exist in the map for this binding point.
find(TIntermSymbol * plsSymbol)73     const T &find(TIntermSymbol *plsSymbol)
74     {
75         ASSERT(plsSymbol);
76         ASSERT(IsPixelLocal(plsSymbol->getBasicType()));
77         int binding = plsSymbol->getType().getLayoutQualifier().binding;
78         ASSERT(binding >= 0);
79         auto iter = mMap.find(binding);
80         ASSERT(iter != mMap.end());  // Ensure PLSImages already exist for this symbol.
81         return iter->second;
82     }
83 
bindingOrderedMap() const84     const std::map<int, T> &bindingOrderedMap() const { return mMap; }
85 
86   private:
87     // Use std::map so the backing stores are ordered by binding when we iterate.
88     std::map<int, T> mMap;
89 };
90 
91 // Base class for rewriting high level PLS operations to AST operations specified by
92 // ShPixelLocalStorageType.
93 class RewritePLSTraverser : public TIntermTraverser
94 {
95   public:
RewritePLSTraverser(TCompiler * compiler,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,int shaderVersion)96     RewritePLSTraverser(TCompiler *compiler,
97                         TSymbolTable &symbolTable,
98                         const ShCompileOptions &compileOptions,
99                         int shaderVersion)
100         : TIntermTraverser(true, false, false, &symbolTable),
101           mCompiler(compiler),
102           mCompileOptions(&compileOptions),
103           mShaderVersion(shaderVersion)
104     {}
105 
visitDeclaration(Visit,TIntermDeclaration * decl)106     bool visitDeclaration(Visit, TIntermDeclaration *decl) override
107     {
108         TIntermTyped *declVariable = (decl->getSequence())->front()->getAsTyped();
109         ASSERT(declVariable);
110 
111         if (!IsPixelLocal(declVariable->getBasicType()))
112         {
113             return true;
114         }
115 
116         // PLS is not allowed in arrays.
117         ASSERT(!declVariable->isArray());
118 
119         // This visitDeclaration doesn't get called for function arguments, and opaque types can
120         // otherwise only be uniforms.
121         ASSERT(declVariable->getQualifier() == EvqUniform);
122 
123         TIntermSymbol *plsSymbol = declVariable->getAsSymbolNode();
124         ASSERT(plsSymbol);
125 
126         visitPLSDeclaration(plsSymbol);
127 
128         return false;
129     }
130 
visitAggregate(Visit,TIntermAggregate * aggregate)131     bool visitAggregate(Visit, TIntermAggregate *aggregate) override
132     {
133         if (!BuiltInGroup::IsPixelLocal(aggregate->getOp()))
134         {
135             return true;
136         }
137 
138         const TIntermSequence &args = *aggregate->getSequence();
139         ASSERT(args.size() >= 1);
140         TIntermSymbol *plsSymbol = args[0]->getAsSymbolNode();
141 
142         // Rewrite pixelLocalLoadANGLE -> imageLoad.
143         if (aggregate->getOp() == EOpPixelLocalLoadANGLE)
144         {
145             visitPLSLoad(plsSymbol);
146             return false;  // No need to recurse since this node is being dropped.
147         }
148 
149         // Rewrite pixelLocalStoreANGLE -> imageStore.
150         if (aggregate->getOp() == EOpPixelLocalStoreANGLE)
151         {
152             // Also hoist the 'value' expression into a temp. In the event of
153             // "pixelLocalStoreANGLE(..., pixelLocalLoadANGLE(...))", this ensures the load occurs
154             // _before_ any potential barriers required by the subclass.
155             //
156             // NOTE: It is generally unsafe to hoist function arguments due to short circuiting,
157             // e.g., "if (false && function(...))", but pixelLocalStoreANGLE returns type void, so
158             // it is safe in this particular case.
159             TType *valueType    = new TType(DataTypeOfPLSType(plsSymbol->getBasicType()),
160                                             plsSymbol->getPrecision(), EvqTemporary, 4);
161             TVariable *valueVar = CreateTempVariable(mSymbolTable, valueType);
162             TIntermDeclaration *valueDecl =
163                 CreateTempInitDeclarationNode(valueVar, args[1]->getAsTyped());
164             valueDecl->traverse(this);  // Rewrite any potential pixelLocalLoadANGLEs in valueDecl.
165             insertStatementInParentBlock(valueDecl);
166 
167             visitPLSStore(plsSymbol, valueVar);
168             return false;  // No need to recurse since this node is being dropped.
169         }
170 
171         return true;
172     }
173 
174     // Called after rewrite. Injects one-time setup code that needs to run before any PLS accesses.
injectPrePLSCode(TCompiler *,TSymbolTable &,const ShCompileOptions &,TIntermBlock * mainBody,size_t plsBeginPosition)175     virtual void injectPrePLSCode(TCompiler *,
176                                   TSymbolTable &,
177                                   const ShCompileOptions &,
178                                   TIntermBlock *mainBody,
179                                   size_t plsBeginPosition)
180     {}
181 
182     // Called after rewrite. Injects one-time finalization code that needs to run after all PLS.
injectPostPLSCode(TCompiler *,TSymbolTable &,const ShCompileOptions &,TIntermBlock * mainBody,size_t plsEndPosition)183     virtual void injectPostPLSCode(TCompiler *,
184                                    TSymbolTable &,
185                                    const ShCompileOptions &,
186                                    TIntermBlock *mainBody,
187                                    size_t plsEndPosition)
188     {}
189 
190     // Called after all other operations have completed.
injectPixelCoordInitializationCodeIfNeeded(TCompiler * compiler,TIntermBlock * root,TIntermBlock * mainBody)191     void injectPixelCoordInitializationCodeIfNeeded(TCompiler *compiler,
192                                                     TIntermBlock *root,
193                                                     TIntermBlock *mainBody)
194     {
195         if (mGlobalPixelCoord)
196         {
197             // Initialize the global pixel coord at the beginning of main():
198             //
199             //     pixelCoord = ivec2(floor(gl_FragCoord.xy));
200             //
201             TIntermTyped *exp;
202             exp = ReferenceBuiltInVariable(ImmutableString("gl_FragCoord"), *mSymbolTable,
203                                            mShaderVersion);
204             exp = CreateSwizzle(exp, 0, 1);
205             exp = CreateBuiltInFunctionCallNode("floor", {exp}, *mSymbolTable, mShaderVersion);
206             exp = TIntermAggregate::CreateConstructor(TType(EbtInt, 2), {exp});
207             exp = CreateTempAssignmentNode(mGlobalPixelCoord, exp);
208             mainBody->insertStatement(0, exp);
209         }
210     }
211 
212   protected:
213     virtual void visitPLSDeclaration(TIntermSymbol *plsSymbol)             = 0;
214     virtual void visitPLSLoad(TIntermSymbol *plsSymbol)                    = 0;
215     virtual void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) = 0;
216 
217     // Inserts a global to hold the pixel coordinate as soon as we see PLS declared. This will be
218     // initialized at the beginning of main().
ensureGlobalPixelCoordDeclared()219     void ensureGlobalPixelCoordDeclared()
220     {
221         if (!mGlobalPixelCoord)
222         {
223             TType *coordType  = new TType(EbtInt, EbpHigh, EvqGlobal, 2);
224             mGlobalPixelCoord = CreateTempVariable(mSymbolTable, coordType);
225             insertStatementInParentBlock(CreateTempDeclarationNode(mGlobalPixelCoord));
226         }
227     }
228 
229     // anglebug.com/42265993: Storing to integer formats with larger-than-representable values has
230     // different behavior on the various APIs.
231     //
232     // This method clamps sub-32-bit integers to the min/max representable values of their format.
clampPLSVarIfNeeded(TVariable * plsVar,TLayoutImageInternalFormat plsFormat)233     void clampPLSVarIfNeeded(TVariable *plsVar, TLayoutImageInternalFormat plsFormat)
234     {
235         switch (plsFormat)
236         {
237             case EiifRGBA8I:
238             {
239                 // Clamp r,g,b,a to their min/max 8-bit values:
240                 //
241                 //     plsVar = clamp(plsVar, -128, 127) & 0xff
242                 //
243                 TIntermTyped *newPLSValue = CreateBuiltInFunctionCallNode(
244                     "clamp",
245                     {new TIntermSymbol(plsVar), CreateIndexNode(-128), CreateIndexNode(127)},
246                     *mSymbolTable, mShaderVersion);
247                 insertStatementInParentBlock(CreateTempAssignmentNode(plsVar, newPLSValue));
248                 break;
249             }
250             case EiifRGBA8UI:
251             {
252                 // Clamp r,g,b,a to their max 8-bit values:
253                 //
254                 //     plsVar = min(plsVar, 255)
255                 //
256                 TIntermTyped *newPLSValue = CreateBuiltInFunctionCallNode(
257                     "min", {new TIntermSymbol(plsVar), CreateUIntNode(255)}, *mSymbolTable,
258                     mShaderVersion);
259                 insertStatementInParentBlock(CreateTempAssignmentNode(plsVar, newPLSValue));
260                 break;
261             }
262             default:
263                 break;
264         }
265     }
266 
267     // Expands an expression to 4 components, filling in the missing components with [0, 0, 0, 1].
Expand(TIntermTyped * expr)268     static TIntermTyped *Expand(TIntermTyped *expr)
269     {
270         const TType &type = expr->getType();
271         ASSERT(type.getNominalSize() == 1 || type.getNominalSize() == 4);
272         if (type.getNominalSize() == 1)
273         {
274             switch (type.getBasicType())
275             {
276                 case EbtFloat:
277                     expr = TIntermAggregate::CreateConstructor(  // "vec4(r, 0, 0, 1)"
278                         TType(EbtFloat, 4),
279                         {expr, CreateFloatNode(0, EbpLow), CreateFloatNode(0, EbpLow),
280                          CreateFloatNode(1, EbpLow)});
281                     break;
282                 case EbtUInt:
283                     expr = TIntermAggregate::CreateConstructor(  // "uvec4(r, 0, 0, 1)"
284                         TType(EbtUInt, 4),
285                         {expr, CreateUIntNode(0), CreateUIntNode(0), CreateUIntNode(1)});
286                     break;
287                 default:
288                     UNREACHABLE();
289                     break;
290             }
291         }
292         return expr;
293     }
294 
Expand(TVariable * var)295     static TIntermTyped *Expand(TVariable *var) { return Expand(new TIntermSymbol(var)); }
296 
297     // Returns an expression that swizzles a variable down to 'n' components.
Swizzle(TVariable * var,int n)298     static TIntermTyped *Swizzle(TVariable *var, int n)
299     {
300         TIntermTyped *swizzled = new TIntermSymbol(var);
301         if (var->getType().getNominalSize() != n)
302         {
303             ASSERT(var->getType().getNominalSize() > n);
304             TVector swizzleOffsets{0, 1, 2, 3};
305             swizzleOffsets.resize(n);
306             swizzled = new TIntermSwizzle(swizzled, swizzleOffsets);
307         }
308         return swizzled;
309     }
310 
311     const TCompiler *const mCompiler;
312     const ShCompileOptions *const mCompileOptions;
313     const int mShaderVersion;
314 
315     // Stores the shader invocation's pixel coordinate as "ivec2(floor(gl_FragCoord.xy))".
316     TVariable *mGlobalPixelCoord = nullptr;
317 };
318 
319 // Rewrites high level PLS operations to shader image operations.
320 class RewritePLSToImagesTraverser : public RewritePLSTraverser
321 {
322   public:
RewritePLSToImagesTraverser(TCompiler * compiler,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,int shaderVersion)323     RewritePLSToImagesTraverser(TCompiler *compiler,
324                                 TSymbolTable &symbolTable,
325                                 const ShCompileOptions &compileOptions,
326                                 int shaderVersion)
327         : RewritePLSTraverser(compiler, symbolTable, compileOptions, shaderVersion)
328     {}
329 
330   private:
visitPLSDeclaration(TIntermSymbol * plsSymbol)331     void visitPLSDeclaration(TIntermSymbol *plsSymbol) override
332     {
333         // Replace the PLS declaration with an image2D.
334         ensureGlobalPixelCoordDeclared();
335         TVariable *image2D = createPLSImageReplacement(plsSymbol);
336         mImages.insertNew(plsSymbol, image2D);
337         queueReplacement(new TIntermDeclaration({new TIntermSymbol(image2D)}),
338                          OriginalNode::IS_DROPPED);
339     }
340 
341     // Creates an image2D that replaces a pixel local storage handle.
createPLSImageReplacement(const TIntermSymbol * plsSymbol)342     TVariable *createPLSImageReplacement(const TIntermSymbol *plsSymbol)
343     {
344         ASSERT(plsSymbol);
345         ASSERT(IsPixelLocal(plsSymbol->getBasicType()));
346 
347         TType *imageType = new TType(plsSymbol->getType());
348 
349         TLayoutQualifier layoutQualifier = imageType->getLayoutQualifier();
350         switch (layoutQualifier.imageInternalFormat)
351         {
352             case TLayoutImageInternalFormat::EiifRGBA8:
353                 if (!mCompileOptions->pls.supportsNativeRGBA8ImageFormats)
354                 {
355                     layoutQualifier.imageInternalFormat = EiifR32UI;
356                     imageType->setPrecision(EbpHigh);
357                     imageType->setBasicType(EbtUImage2D);
358                 }
359                 else
360                 {
361                     imageType->setBasicType(EbtImage2D);
362                 }
363                 break;
364             case TLayoutImageInternalFormat::EiifRGBA8I:
365                 if (!mCompileOptions->pls.supportsNativeRGBA8ImageFormats)
366                 {
367                     layoutQualifier.imageInternalFormat = EiifR32I;
368                     imageType->setPrecision(EbpHigh);
369                 }
370                 imageType->setBasicType(EbtIImage2D);
371                 break;
372             case TLayoutImageInternalFormat::EiifRGBA8UI:
373                 if (!mCompileOptions->pls.supportsNativeRGBA8ImageFormats)
374                 {
375                     layoutQualifier.imageInternalFormat = EiifR32UI;
376                     imageType->setPrecision(EbpHigh);
377                 }
378                 imageType->setBasicType(EbtUImage2D);
379                 break;
380             case TLayoutImageInternalFormat::EiifR32F:
381                 imageType->setBasicType(EbtImage2D);
382                 break;
383             case TLayoutImageInternalFormat::EiifR32UI:
384                 imageType->setBasicType(EbtUImage2D);
385                 break;
386             default:
387                 UNREACHABLE();
388         }
389         layoutQualifier.rasterOrdered =
390             mCompileOptions->pls.fragmentSyncType ==
391                 ShFragmentSynchronizationType::RasterizerOrderViews_D3D ||
392             mCompileOptions->pls.fragmentSyncType ==
393                 ShFragmentSynchronizationType::RasterOrderGroups_Metal;
394         imageType->setLayoutQualifier(layoutQualifier);
395 
396         TMemoryQualifier memoryQualifier{};
397         memoryQualifier.coherent          = true;
398         memoryQualifier.restrictQualifier = true;
399         memoryQualifier.volatileQualifier = false;
400         // TODO(anglebug.com/40096838): Maybe we could walk the tree first and see which PLS is used
401         // how. If the PLS is never loaded, we could add a writeonly qualifier, for example.
402         memoryQualifier.readonly  = false;
403         memoryQualifier.writeonly = false;
404         imageType->setMemoryQualifier(memoryQualifier);
405 
406         const TVariable &plsVar = plsSymbol->variable();
407         return new TVariable(plsVar.uniqueId(), plsVar.name(), plsVar.symbolType(),
408                              plsVar.extensions(), imageType);
409     }
410 
visitPLSLoad(TIntermSymbol * plsSymbol)411     void visitPLSLoad(TIntermSymbol *plsSymbol) override
412     {
413         // Replace the pixelLocalLoadANGLE with imageLoad.
414         TVariable *image2D = mImages.find(plsSymbol);
415         ASSERT(mGlobalPixelCoord);
416         TIntermTyped *pls = CreateBuiltInFunctionCallNode(
417             "imageLoad", {new TIntermSymbol(image2D), new TIntermSymbol(mGlobalPixelCoord)},
418             *mSymbolTable, 310);
419         pls = unpackImageDataIfNecessary(pls, plsSymbol, image2D);
420         queueReplacement(pls, OriginalNode::IS_DROPPED);
421     }
422 
423     // Unpacks the raw PLS data if the output shader language needs r32* packing.
unpackImageDataIfNecessary(TIntermTyped * data,TIntermSymbol * plsSymbol,TVariable * image2D)424     TIntermTyped *unpackImageDataIfNecessary(TIntermTyped *data,
425                                              TIntermSymbol *plsSymbol,
426                                              TVariable *image2D)
427     {
428         TLayoutImageInternalFormat plsFormat =
429             plsSymbol->getType().getLayoutQualifier().imageInternalFormat;
430         TLayoutImageInternalFormat imageFormat =
431             image2D->getType().getLayoutQualifier().imageInternalFormat;
432         if (plsFormat == imageFormat)
433         {
434             return data;  // This PLS storage isn't packed.
435         }
436         switch (plsFormat)
437         {
438             case EiifRGBA8:
439                 ASSERT(!mCompileOptions->pls.supportsNativeRGBA8ImageFormats);
440                 // Unpack and normalize r,g,b,a from a single 32-bit unsigned int:
441                 //
442                 //     unpackUnorm4x8(data.r)
443                 //
444                 data = CreateBuiltInFunctionCallNode("unpackUnorm4x8", {CreateSwizzle(data, 0)},
445                                                      *mSymbolTable, 310);
446                 break;
447             case EiifRGBA8I:
448             case EiifRGBA8UI:
449             {
450                 ASSERT(!mCompileOptions->pls.supportsNativeRGBA8ImageFormats);
451                 constexpr unsigned shifts[] = {24, 16, 8, 0};
452                 // Unpack r,g,b,a form a single (signed or unsigned) 32-bit int. Shift left,
453                 // then right, to preserve the sign for ints. (highp integers are exactly
454                 // 32-bit, two's compliment.)
455                 //
456                 //     data.rrrr << uvec4(24, 16, 8, 0) >> 24u
457                 //
458                 data = CreateSwizzle(data, 0, 0, 0, 0);
459                 data = new TIntermBinary(EOpBitShiftLeft, data, CreateUVecNode(shifts, 4, EbpLow));
460                 data = new TIntermBinary(EOpBitShiftRight, data, CreateUIntNode(24));
461                 break;
462             }
463             default:
464                 UNREACHABLE();
465         }
466         return data;
467     }
468 
visitPLSStore(TIntermSymbol * plsSymbol,TVariable * value)469     void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) override
470     {
471         TVariable *image2D       = mImages.find(plsSymbol);
472         TIntermTyped *packedData = clampAndPackPLSDataIfNecessary(value, plsSymbol, image2D);
473 
474         // Surround the store with memoryBarrierImage calls in order to ensure dependent stores and
475         // loads in a single shader invocation are coherent. From the ES 3.1 spec:
476         //
477         //   Using variables declared as "coherent" guarantees only that the results of stores will
478         //   be immediately visible to shader invocations using similarly-declared variables;
479         //   calling MemoryBarrier is required to ensure that the stores are visible to other
480         //   operations.
481         //
482         insertStatementsInParentBlock(
483             {CreateBuiltInFunctionCallNode("memoryBarrierImage", {}, *mSymbolTable,
484                                            310)},  // Before.
485             {CreateBuiltInFunctionCallNode("memoryBarrierImage", {}, *mSymbolTable,
486                                            310)});  // After.
487 
488         // Rewrite the pixelLocalStoreANGLE with imageStore.
489         ASSERT(mGlobalPixelCoord);
490         queueReplacement(
491             CreateBuiltInFunctionCallNode(
492                 "imageStore",
493                 {new TIntermSymbol(image2D), new TIntermSymbol(mGlobalPixelCoord), packedData},
494                 *mSymbolTable, 310),
495             OriginalNode::IS_DROPPED);
496     }
497 
498     // Packs the PLS to raw data if the output shader language needs r32* packing.
clampAndPackPLSDataIfNecessary(TVariable * plsVar,TIntermSymbol * plsSymbol,TVariable * image2D)499     TIntermTyped *clampAndPackPLSDataIfNecessary(TVariable *plsVar,
500                                                  TIntermSymbol *plsSymbol,
501                                                  TVariable *image2D)
502     {
503         TLayoutImageInternalFormat plsFormat =
504             plsSymbol->getType().getLayoutQualifier().imageInternalFormat;
505         clampPLSVarIfNeeded(plsVar, plsFormat);
506         TIntermTyped *result = new TIntermSymbol(plsVar);
507         TLayoutImageInternalFormat imageFormat =
508             image2D->getType().getLayoutQualifier().imageInternalFormat;
509         if (plsFormat == imageFormat)
510         {
511             return result;  // This PLS storage isn't packed.
512         }
513         switch (plsFormat)
514         {
515             case EiifRGBA8:
516             {
517                 ASSERT(!mCompileOptions->pls.supportsNativeRGBA8ImageFormats);
518                 if (mCompileOptions->passHighpToPackUnormSnormBuiltins)
519                 {
520                     // anglebug.com/42265995: unpackUnorm4x8 doesn't work on Pixel 4 when passed
521                     // a mediump vec4. Use an intermediate highp vec4.
522                     //
523                     // It's safe to inject a variable here because it happens right before
524                     // pixelLocalStoreANGLE, which returns type void. (See visitAggregate.)
525                     TType *highpType              = new TType(EbtFloat, EbpHigh, EvqTemporary, 4);
526                     TVariable *workaroundHighpVar = CreateTempVariable(mSymbolTable, highpType);
527                     insertStatementInParentBlock(
528                         CreateTempInitDeclarationNode(workaroundHighpVar, result));
529                     result = new TIntermSymbol(workaroundHighpVar);
530                 }
531 
532                 // Denormalize and pack r,g,b,a into a single 32-bit unsigned int:
533                 //
534                 //     packUnorm4x8(workaroundHighpVar)
535                 //
536                 result =
537                     CreateBuiltInFunctionCallNode("packUnorm4x8", {result}, *mSymbolTable, 310);
538                 break;
539             }
540             case EiifRGBA8I:
541             case EiifRGBA8UI:
542             {
543                 ASSERT(!mCompileOptions->pls.supportsNativeRGBA8ImageFormats);
544                 if (plsFormat == EiifRGBA8I)
545                 {
546                     // Mask off extra sign bits beyond 8.
547                     //
548                     //     plsVar &= 0xff
549                     //
550                     insertStatementInParentBlock(new TIntermBinary(
551                         EOpBitwiseAndAssign, new TIntermSymbol(plsVar), CreateIndexNode(0xff)));
552                 }
553                 // Pack r,g,b,a into a single 32-bit (signed or unsigned) int:
554                 //
555                 //     r | (g << 8) | (b << 16) | (a << 24)
556                 //
557                 auto shiftComponent = [=](int componentIdx) {
558                     return new TIntermBinary(EOpBitShiftLeft,
559                                              CreateSwizzle(new TIntermSymbol(plsVar), componentIdx),
560                                              CreateUIntNode(componentIdx * 8));
561                 };
562                 result = CreateSwizzle(result, 0);
563                 result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(1));
564                 result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(2));
565                 result = new TIntermBinary(EOpBitwiseOr, result, shiftComponent(3));
566                 break;
567             }
568             default:
569                 UNREACHABLE();
570         }
571         // Convert the packed data to a {u,i}vec4 for imageStore.
572         TType imageStoreType(DataTypeOfImageType(image2D->getType().getBasicType()), 4);
573         return TIntermAggregate::CreateConstructor(imageStoreType, {result});
574     }
575 
injectPrePLSCode(TCompiler * compiler,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,TIntermBlock * mainBody,size_t plsBeginPosition)576     void injectPrePLSCode(TCompiler *compiler,
577                           TSymbolTable &symbolTable,
578                           const ShCompileOptions &compileOptions,
579                           TIntermBlock *mainBody,
580                           size_t plsBeginPosition) override
581     {
582         // When PLS is implemented with images, early_fragment_tests ensure that depth/stencil
583         // can also block stores to PLS.
584         compiler->specifyEarlyFragmentTests();
585 
586         // Delimit the beginning of a per-pixel critical section, if supported. This makes pixel
587         // local storage coherent.
588         //
589         // Either: GL_NV_fragment_shader_interlock
590         //         GL_INTEL_fragment_shader_ordering
591         //         GL_ARB_fragment_shader_interlock (may compile to
592         //                                           SPV_EXT_fragment_shader_interlock)
593         switch (compileOptions.pls.fragmentSyncType)
594         {
595             // Raster ordered resources don't need explicit synchronization calls.
596             case ShFragmentSynchronizationType::RasterizerOrderViews_D3D:
597             case ShFragmentSynchronizationType::RasterOrderGroups_Metal:
598             case ShFragmentSynchronizationType::NotSupported:
599                 break;
600             case ShFragmentSynchronizationType::FragmentShaderInterlock_NV_GL:
601                 mainBody->insertStatement(
602                     plsBeginPosition,
603                     CreateBuiltInFunctionCallNode("beginInvocationInterlockNV", {}, symbolTable,
604                                                   kESSLInternalBackendBuiltIns));
605                 break;
606             case ShFragmentSynchronizationType::FragmentShaderOrdering_INTEL_GL:
607                 mainBody->insertStatement(
608                     plsBeginPosition,
609                     CreateBuiltInFunctionCallNode("beginFragmentShaderOrderingINTEL", {},
610                                                   symbolTable, kESSLInternalBackendBuiltIns));
611                 break;
612             case ShFragmentSynchronizationType::FragmentShaderInterlock_ARB_GL:
613                 mainBody->insertStatement(
614                     plsBeginPosition,
615                     CreateBuiltInFunctionCallNode("beginInvocationInterlockARB", {}, symbolTable,
616                                                   kESSLInternalBackendBuiltIns));
617                 break;
618             default:
619                 UNREACHABLE();
620         }
621     }
622 
injectPostPLSCode(TCompiler *,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,TIntermBlock * mainBody,size_t plsEndPosition)623     void injectPostPLSCode(TCompiler *,
624                            TSymbolTable &symbolTable,
625                            const ShCompileOptions &compileOptions,
626                            TIntermBlock *mainBody,
627                            size_t plsEndPosition) override
628     {
629         // Delimit the end of the PLS critical section, if required.
630         //
631         // Either: GL_NV_fragment_shader_interlock
632         //         GL_ARB_fragment_shader_interlock (may compile to
633         //                                           SPV_EXT_fragment_shader_interlock)
634         switch (compileOptions.pls.fragmentSyncType)
635         {
636             // Raster ordered resources don't need explicit synchronization calls.
637             case ShFragmentSynchronizationType::RasterizerOrderViews_D3D:
638             case ShFragmentSynchronizationType::RasterOrderGroups_Metal:
639             // GL_INTEL_fragment_shader_ordering doesn't have an "end()" call.
640             case ShFragmentSynchronizationType::FragmentShaderOrdering_INTEL_GL:
641             case ShFragmentSynchronizationType::NotSupported:
642                 break;
643             case ShFragmentSynchronizationType::FragmentShaderInterlock_NV_GL:
644 
645                 mainBody->insertStatement(
646                     plsEndPosition,
647                     CreateBuiltInFunctionCallNode("endInvocationInterlockNV", {}, symbolTable,
648                                                   kESSLInternalBackendBuiltIns));
649                 break;
650             case ShFragmentSynchronizationType::FragmentShaderInterlock_ARB_GL:
651                 mainBody->insertStatement(
652                     plsEndPosition,
653                     CreateBuiltInFunctionCallNode("endInvocationInterlockARB", {}, symbolTable,
654                                                   kESSLInternalBackendBuiltIns));
655                 break;
656             default:
657                 UNREACHABLE();
658         }
659     }
660 
661     PLSBackingStoreMap<TVariable *> mImages;
662 };
663 
664 // Rewrites high level PLS operations to framebuffer fetch operations.
665 class RewritePLSToFramebufferFetchTraverser : public RewritePLSTraverser
666 {
667   public:
RewritePLSToFramebufferFetchTraverser(TCompiler * compiler,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,int shaderVersion)668     RewritePLSToFramebufferFetchTraverser(TCompiler *compiler,
669                                           TSymbolTable &symbolTable,
670                                           const ShCompileOptions &compileOptions,
671                                           int shaderVersion)
672         : RewritePLSTraverser(compiler, symbolTable, compileOptions, shaderVersion)
673     {}
674 
visitPLSDeclaration(TIntermSymbol * plsSymbol)675     void visitPLSDeclaration(TIntermSymbol *plsSymbol) override
676     {
677         // Replace the PLS declaration with a framebuffer attachment.
678         PLSAttachment attachment(mCompiler, mSymbolTable, *mCompileOptions, plsSymbol->variable());
679         mPLSAttachments.insertNew(plsSymbol, attachment);
680         insertStatementInParentBlock(
681             new TIntermDeclaration({new TIntermSymbol(attachment.fragmentVar)}));
682         queueReplacement(CreateTempDeclarationNode(attachment.accessVar), OriginalNode::IS_DROPPED);
683     }
684 
visitPLSLoad(TIntermSymbol * plsSymbol)685     void visitPLSLoad(TIntermSymbol *plsSymbol) override
686     {
687         // Read our temporary accessVar.
688         const PLSAttachment &attachment = mPLSAttachments.find(plsSymbol);
689         queueReplacement(Expand(attachment.accessVar), OriginalNode::IS_DROPPED);
690     }
691 
visitPLSStore(TIntermSymbol * plsSymbol,TVariable * value)692     void visitPLSStore(TIntermSymbol *plsSymbol, TVariable *value) override
693     {
694         // Set our temporary accessVar.
695         const PLSAttachment &attachment = mPLSAttachments.find(plsSymbol);
696         queueReplacement(CreateTempAssignmentNode(attachment.accessVar, attachment.swizzle(value)),
697                          OriginalNode::IS_DROPPED);
698     }
699 
injectPrePLSCode(TCompiler * compiler,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,TIntermBlock * mainBody,size_t plsBeginPosition)700     void injectPrePLSCode(TCompiler *compiler,
701                           TSymbolTable &symbolTable,
702                           const ShCompileOptions &compileOptions,
703                           TIntermBlock *mainBody,
704                           size_t plsBeginPosition) override
705     {
706         // [OpenGL ES Version 3.0.6, 3.9.2.3 "Shader Output"]: Any colors, or color components,
707         // associated with a fragment that are not written by the fragment shader are undefined.
708         //
709         // [EXT_shader_framebuffer_fetch]: Prior to fragment shading, fragment outputs declared
710         // inout are populated with the value last written to the framebuffer at the same(x, y,
711         // sample) position.
712         //
713         // It's unclear from the EXT_shader_framebuffer_fetch spec whether inout fragment variables
714         // become undefined if not explicitly written, but either way, when this compiles to subpass
715         // loads in Vulkan, we definitely get undefined behavior if PLS variables are not written.
716         //
717         // To make sure every PLS variable gets written, we read them all before PLS operations,
718         // then write them all back out after all PLS is complete.
719         TIntermSequence plsPreloads;
720         plsPreloads.reserve(mPLSAttachments.bindingOrderedMap().size());
721         for (const auto &entry : mPLSAttachments.bindingOrderedMap())
722         {
723             const PLSAttachment &attachment = entry.second;
724             plsPreloads.push_back(
725                 CreateTempAssignmentNode(attachment.accessVar, attachment.swizzleFragmentVar()));
726         }
727         mainBody->getSequence()->insert(mainBody->getSequence()->begin() + plsBeginPosition,
728                                         plsPreloads.begin(), plsPreloads.end());
729     }
730 
injectPostPLSCode(TCompiler *,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,TIntermBlock * mainBody,size_t plsEndPosition)731     void injectPostPLSCode(TCompiler *,
732                            TSymbolTable &symbolTable,
733                            const ShCompileOptions &compileOptions,
734                            TIntermBlock *mainBody,
735                            size_t plsEndPosition) override
736     {
737         TIntermSequence plsWrites;
738         plsWrites.reserve(mPLSAttachments.bindingOrderedMap().size());
739         for (const auto &entry : mPLSAttachments.bindingOrderedMap())
740         {
741             const PLSAttachment &attachment = entry.second;
742             plsWrites.push_back(new TIntermBinary(EOpAssign, attachment.swizzleFragmentVar(),
743                                                   new TIntermSymbol(attachment.accessVar)));
744         }
745         mainBody->getSequence()->insert(mainBody->getSequence()->begin() + plsEndPosition,
746                                         plsWrites.begin(), plsWrites.end());
747     }
748 
749   private:
750     struct PLSAttachment
751     {
PLSAttachmentsh::__anonef336ebc0111::RewritePLSToFramebufferFetchTraverser::PLSAttachment752         PLSAttachment(const TCompiler *compiler,
753                       TSymbolTable *symbolTable,
754                       const ShCompileOptions &compileOptions,
755                       const TVariable &plsVar)
756         {
757             const TType &plsType = plsVar.getType();
758 
759             TType *accessVarType;
760             switch (plsType.getLayoutQualifier().imageInternalFormat)
761             {
762                 default:
763                     UNREACHABLE();
764                     [[fallthrough]];
765                 case EiifRGBA8:
766                     accessVarType = new TType(EbtFloat, 4);
767                     break;
768                 case EiifRGBA8I:
769                     accessVarType = new TType(EbtInt, 4);
770                     break;
771                 case EiifRGBA8UI:
772                     accessVarType = new TType(EbtUInt, 4);
773                     break;
774                 case EiifR32F:
775                     accessVarType = new TType(EbtFloat, 1);
776                     break;
777                 case EiifR32UI:
778                     accessVarType = new TType(EbtUInt, 1);
779                     break;
780             }
781             accessVarType->setPrecision(plsType.getPrecision());
782             accessVar = CreateTempVariable(symbolTable, accessVarType);
783 
784             // Qualcomm seems to want fragment outputs to be 4-component vectors, and produces a
785             // compile error from "inout uint". Our Metal translator also saturates color outputs to
786             // 4 components. And since the spec also seems silent on how many components an output
787             // must have, we always use 4.
788             TType *fragmentVarType = new TType(accessVarType->getBasicType(), 4);
789             fragmentVarType->setPrecision(plsType.getPrecision());
790             fragmentVarType->setQualifier(EvqFragmentInOut);
791 
792             // PLS attachments are bound in reverse order from the rear.
793             TLayoutQualifier layoutQualifier = TLayoutQualifier::Create();
794             layoutQualifier.location =
795                 compiler->getResources().MaxCombinedDrawBuffersAndPixelLocalStoragePlanes -
796                 plsType.getLayoutQualifier().binding - 1;
797             layoutQualifier.locationsSpecified = 1;
798             if (compileOptions.pls.fragmentSyncType == ShFragmentSynchronizationType::NotSupported)
799             {
800                 // We're using EXT_shader_framebuffer_fetch_non_coherent, which requires the
801                 // "noncoherent" qualifier.
802                 layoutQualifier.noncoherent = true;
803             }
804             fragmentVarType->setLayoutQualifier(layoutQualifier);
805 
806             fragmentVar = new TVariable(plsVar.uniqueId(), plsVar.name(), plsVar.symbolType(),
807                                         plsVar.extensions(), fragmentVarType);
808         }
809 
810         // Swizzles a variable down to the same number of components as the PLS internalformat.
swizzlesh::__anonef336ebc0111::RewritePLSToFramebufferFetchTraverser::PLSAttachment811         TIntermTyped *swizzle(TVariable *var) const
812         {
813             return Swizzle(var, accessVar->getType().getNominalSize());
814         }
815 
swizzleFragmentVarsh::__anonef336ebc0111::RewritePLSToFramebufferFetchTraverser::PLSAttachment816         TIntermTyped *swizzleFragmentVar() const { return swizzle(fragmentVar); }
817 
818         TVariable *fragmentVar;
819         TVariable *accessVar;
820     };
821 
822     PLSBackingStoreMap<PLSAttachment> mPLSAttachments;
823 };
824 
825 }  // anonymous namespace
826 
RewritePixelLocalStorage(TCompiler * compiler,TIntermBlock * root,TSymbolTable & symbolTable,const ShCompileOptions & compileOptions,int shaderVersion)827 bool RewritePixelLocalStorage(TCompiler *compiler,
828                               TIntermBlock *root,
829                               TSymbolTable &symbolTable,
830                               const ShCompileOptions &compileOptions,
831                               int shaderVersion)
832 {
833     // If any functions take PLS arguments, monomorphize the functions by removing said parameters
834     // and making the PLS calls from main() instead, using the global uniform from the call site
835     // instead of the function argument. This is necessary because function arguments don't carry
836     // the necessary "binding" or "format" layout qualifiers.
837     if (!MonomorphizeUnsupportedFunctions(
838             compiler, root, &symbolTable,
839             UnsupportedFunctionArgsBitSet{UnsupportedFunctionArgs::PixelLocalStorage}))
840     {
841         return false;
842     }
843 
844     TIntermBlock *mainBody = FindMainBody(root);
845 
846     std::unique_ptr<RewritePLSTraverser> traverser;
847     switch (compileOptions.pls.type)
848     {
849         case ShPixelLocalStorageType::ImageLoadStore:
850             traverser = std::make_unique<RewritePLSToImagesTraverser>(
851                 compiler, symbolTable, compileOptions, shaderVersion);
852             break;
853         case ShPixelLocalStorageType::FramebufferFetch:
854             traverser = std::make_unique<RewritePLSToFramebufferFetchTraverser>(
855                 compiler, symbolTable, compileOptions, shaderVersion);
856             break;
857         case ShPixelLocalStorageType::NotSupported:
858             UNREACHABLE();
859             return false;
860     }
861 
862     // Rewrite PLS operations.
863     root->traverse(traverser.get());
864     if (!traverser->updateTree(compiler, root))
865     {
866         return false;
867     }
868 
869     // Inject the code that needs to run before and after all PLS operations.
870     // TODO(anglebug.com/40096838): Inject these functions in a tight critical section, instead of
871     // just locking the entire main() function:
872     //   - Monomorphize all PLS calls into main().
873     //   - Insert begin/end calls around the first/last PLS calls (and outside of flow control).
874     traverser->injectPrePLSCode(compiler, symbolTable, compileOptions, mainBody, 0);
875     traverser->injectPostPLSCode(compiler, symbolTable, compileOptions, mainBody,
876                                  mainBody->getChildCount());
877 
878     // Assign the global pixel coord at the beginning of main(), if used.
879     traverser->injectPixelCoordInitializationCodeIfNeeded(compiler, root, mainBody);
880 
881     return compiler->validateAST(root);
882 }
883 }  // namespace sh
884