xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/spirv/EmulateAdvancedBlendEquations.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2022 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // EmulateAdvancedBlendEquations.cpp: Emulate advanced blend equations by implicitly reading back
7 // from the color attachment (as an input attachment) and apply the equation function based on a
8 // uniform.
9 //
10 
11 #include "compiler/translator/tree_ops/spirv/EmulateAdvancedBlendEquations.h"
12 
13 #include <map>
14 
15 #include "GLSLANG/ShaderVars.h"
16 #include "common/PackedEnums.h"
17 #include "compiler/translator/Compiler.h"
18 #include "compiler/translator/StaticType.h"
19 #include "compiler/translator/SymbolTable.h"
20 #include "compiler/translator/tree_util/DriverUniform.h"
21 #include "compiler/translator/tree_util/FindMain.h"
22 #include "compiler/translator/tree_util/IntermNode_util.h"
23 #include "compiler/translator/tree_util/IntermTraverse.h"
24 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
25 
26 namespace sh
27 {
28 namespace
29 {
30 
31 // All helper functions that may be generated.
32 class Builder
33 {
34   public:
Builder(TCompiler * compiler,TSymbolTable * symbolTable,const AdvancedBlendEquations & advancedBlendEquations,const DriverUniform * driverUniforms,InputAttachmentMap * inputAttachmentMap)35     Builder(TCompiler *compiler,
36             TSymbolTable *symbolTable,
37             const AdvancedBlendEquations &advancedBlendEquations,
38             const DriverUniform *driverUniforms,
39             InputAttachmentMap *inputAttachmentMap)
40         : mCompiler(compiler),
41           mSymbolTable(symbolTable),
42           mDriverUniforms(driverUniforms),
43           mInputAttachmentMap(inputAttachmentMap),
44           mAdvancedBlendEquations(advancedBlendEquations)
45     {}
46 
47     bool build(TIntermBlock *root);
48 
49   private:
50     void findColorOutput(TIntermBlock *root);
51     void createSubpassInputVar(TIntermBlock *root);
52     void generateHslHelperFunctions();
53     void generateBlendFunctions();
54     void insertGeneratedFunctions(TIntermBlock *root);
55     TIntermTyped *divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor);
56     TIntermSymbol *premultiplyAlpha(TIntermBlock *blendBlock, TIntermTyped *var, const char *name);
57     void generatePreamble(TIntermBlock *blendBlock);
58     void generateEquationSwitch(TIntermBlock *blendBlock);
59 
60     TCompiler *mCompiler;
61     TSymbolTable *mSymbolTable;
62     const DriverUniform *mDriverUniforms;
63     InputAttachmentMap *mInputAttachmentMap;
64     const AdvancedBlendEquations &mAdvancedBlendEquations;
65 
66     // The color input and output.  Output is the blend source, and input is the destination.
67     const TVariable *mSubpassInputVar = nullptr;
68     const TVariable *mOutputVar       = nullptr;
69 
70     // The value of output, premultiplied by alpha
71     TIntermSymbol *mSrc = nullptr;
72     // The value of input, premultiplied by alpha
73     TIntermSymbol *mDst = nullptr;
74 
75     // p0, p1 and p2 coefficients
76     TIntermSymbol *mP0 = nullptr;
77     TIntermSymbol *mP1 = nullptr;
78     TIntermSymbol *mP2 = nullptr;
79 
80     // Functions implementing an advanced blend equation:
81     angle::PackedEnumMap<gl::BlendEquationType, TIntermFunctionDefinition *> mBlendFuncs = {};
82 
83     // HSL helpers:
84     TIntermFunctionDefinition *mMinv3     = nullptr;
85     TIntermFunctionDefinition *mMaxv3     = nullptr;
86     TIntermFunctionDefinition *mLumv3     = nullptr;
87     TIntermFunctionDefinition *mSatv3     = nullptr;
88     TIntermFunctionDefinition *mClipColor = nullptr;
89     TIntermFunctionDefinition *mSetLum    = nullptr;
90     TIntermFunctionDefinition *mSetLumSat = nullptr;
91 };
92 
build(TIntermBlock * root)93 bool Builder::build(TIntermBlock *root)
94 {
95     // Find the output variable for which advanced blend is specified.  Note that advanced blend can
96     // only be used when rendering is done to a single color attachment.
97     findColorOutput(root);
98     if (mSubpassInputVar == nullptr)
99     {
100         createSubpassInputVar(root);
101     }
102 
103     // If any HSL blend equation is used, generate a few utility functions used in Table X.2 in the
104     // spec.
105     if (mAdvancedBlendEquations.anyHsl())
106     {
107         generateHslHelperFunctions();
108     }
109 
110     // Generate a function for each enabled blend equation.  This is |f| in the spec.
111     generateBlendFunctions();
112 
113     // Insert the generated functions to root.
114     insertGeneratedFunctions(root);
115 
116     // Prepare for blend by:
117     //
118     // - Premultiplying src and dst color by alpha
119     // - Calculating p0, p1 and p2
120     //
121     // Note that the color coefficients (X,Y,Z) are always (1,1,1) in the KHR extension (they were
122     // not in the NV extension), so they are implicitly dropped.
123     TIntermBlock *blendBlock = new TIntermBlock;
124     generatePreamble(blendBlock);
125 
126     // Generate the |switch| that calls the right function based on a driver uniform.
127     generateEquationSwitch(blendBlock);
128 
129     // Place the entire blend block under an if (equation != 0)
130     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
131     TIntermTyped *notZero = new TIntermBinary(EOpNotEqual, equationUniform, CreateUIntNode(0));
132 
133     TIntermIfElse *blend = new TIntermIfElse(notZero, blendBlock, nullptr);
134     return RunAtTheEndOfShader(mCompiler, root, blend, mSymbolTable);
135 }
136 
findColorOutput(TIntermBlock * root)137 void Builder::findColorOutput(TIntermBlock *root)
138 {
139     for (TIntermNode *node : *root->getSequence())
140     {
141         TIntermDeclaration *asDecl = node->getAsDeclarationNode();
142         if (asDecl == nullptr)
143         {
144             continue;
145         }
146 
147         // SeparateDeclarations should have already been run.
148         ASSERT(asDecl->getSequence()->size() == 1u);
149 
150         TIntermSymbol *symbol = asDecl->getSequence()->front()->getAsSymbolNode();
151         if (symbol == nullptr)
152         {
153             continue;
154         }
155 
156         const TType &type = symbol->getType();
157         if (type.getQualifier() == EvqFragmentOut || type.getQualifier() == EvqFragmentInOut)
158         {
159             // There can only be one output with advanced blend per spec.
160             // If there are multiple outputs, take the one one with location 0.
161             if (mOutputVar == nullptr || mOutputVar->getType().getLayoutQualifier().location > 0)
162             {
163                 mOutputVar = &symbol->variable();
164             }
165         }
166 
167         if (IsSubpassInputType(type.getBasicType()) &&
168             symbol->getName() != "ANGLEDepthInputAttachment" &&
169             symbol->getName() != "ANGLEStencilInputAttachment")
170         {
171             // There can only be one output with advanced blend, so there can only be a maximum of
172             // one subpass input already defined (by framebuffer fetch emulation).
173             ASSERT(mSubpassInputVar == nullptr);
174             mSubpassInputVar = &symbol->variable();
175         }
176     }
177 
178     // This transformation is only ever called when advanced blend is specified.
179     ASSERT(mOutputVar != nullptr);
180 }
181 
MakeVariable(TSymbolTable * symbolTable,const char * name,const TType * type)182 TIntermSymbol *MakeVariable(TSymbolTable *symbolTable, const char *name, const TType *type)
183 {
184     const TVariable *var =
185         new TVariable(symbolTable, ImmutableString(name), type, SymbolType::AngleInternal);
186     return new TIntermSymbol(var);
187 }
188 
createSubpassInputVar(TIntermBlock * root)189 void Builder::createSubpassInputVar(TIntermBlock *root)
190 {
191     const TPrecision precision = mOutputVar->getType().getPrecision();
192 
193     // The input attachment index used for this color attachment would be identical to its location
194     // (or implicitly 0 if not specified).
195     const unsigned int inputAttachmentIndex =
196         std::max(0, mOutputVar->getType().getLayoutQualifier().location);
197 
198     // Note that blending can only happen on float/fixed-point output.
199     ASSERT(mOutputVar->getType().getBasicType() == EbtFloat);
200 
201     // Create the subpass input uniform.
202     TType *inputAttachmentType = new TType(EbtSubpassInput, precision, EvqUniform, 1);
203     TLayoutQualifier inputAttachmentQualifier     = inputAttachmentType->getLayoutQualifier();
204     inputAttachmentQualifier.inputAttachmentIndex = inputAttachmentIndex;
205     inputAttachmentType->setLayoutQualifier(inputAttachmentQualifier);
206 
207     const char *kSubpassInputName = "ANGLEFragmentInput";
208     TIntermSymbol *subpassInputSymbol =
209         MakeVariable(mSymbolTable, kSubpassInputName, inputAttachmentType);
210     mSubpassInputVar = &subpassInputSymbol->variable();
211 
212     // Add its declaration to the shader.
213     TIntermDeclaration *subpassInputDecl = new TIntermDeclaration;
214     subpassInputDecl->appendDeclarator(subpassInputSymbol);
215     root->insertStatement(0, subpassInputDecl);
216 
217     mInputAttachmentMap->color[inputAttachmentIndex] = mSubpassInputVar;
218 }
219 
Float(float f)220 TIntermTyped *Float(float f)
221 {
222     return CreateFloatNode(f, EbpMedium);
223 }
224 
MakeFunction(TSymbolTable * symbolTable,const char * name,const TType * returnType,const TVector<const TVariable * > & args)225 TFunction *MakeFunction(TSymbolTable *symbolTable,
226                         const char *name,
227                         const TType *returnType,
228                         const TVector<const TVariable *> &args)
229 {
230     TFunction *function = new TFunction(symbolTable, ImmutableString(name),
231                                         SymbolType::AngleInternal, returnType, false);
232     for (const TVariable *arg : args)
233     {
234         function->addParameter(arg);
235     }
236     return function;
237 }
238 
MakeFunctionDefinition(const TFunction * function,TIntermBlock * body)239 TIntermFunctionDefinition *MakeFunctionDefinition(const TFunction *function, TIntermBlock *body)
240 {
241     return new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
242 }
243 
MakeSimpleFunctionDefinition(TSymbolTable * symbolTable,const char * name,TIntermTyped * returnExpression,const TVector<TIntermSymbol * > & args)244 TIntermFunctionDefinition *MakeSimpleFunctionDefinition(TSymbolTable *symbolTable,
245                                                         const char *name,
246                                                         TIntermTyped *returnExpression,
247                                                         const TVector<TIntermSymbol *> &args)
248 {
249     TVector<const TVariable *> argsAsVar;
250     for (TIntermSymbol *arg : args)
251     {
252         argsAsVar.push_back(&arg->variable());
253     }
254 
255     TIntermBlock *body = new TIntermBlock;
256     body->appendStatement(new TIntermBranch(EOpReturn, returnExpression));
257 
258     const TFunction *function =
259         MakeFunction(symbolTable, name, &returnExpression->getType(), argsAsVar);
260     return MakeFunctionDefinition(function, body);
261 }
262 
generateHslHelperFunctions()263 void Builder::generateHslHelperFunctions()
264 {
265     const TPrecision precision = mOutputVar->getType().getPrecision();
266 
267     TType *floatType     = new TType(EbtFloat, precision, EvqTemporary, 1);
268     TType *vec3Type      = new TType(EbtFloat, precision, EvqTemporary, 3);
269     TType *vec3ParamType = new TType(EbtFloat, precision, EvqParamIn, 3);
270 
271     // float ANGLE_minv3(vec3 c)
272     // {
273     //     return min(min(c.r, c.g), c.b);
274     // }
275     {
276         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
277 
278         TIntermTyped *cR = new TIntermSwizzle(c, {0});
279         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
280         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
281 
282         // min(c.r, c.g)
283         TIntermSequence cRcG = {cR, cG};
284         TIntermTyped *minRG  = CreateBuiltInFunctionCallNode("min", &cRcG, *mSymbolTable, 100);
285 
286         // min(min(c.r, c.g), c.b)
287         TIntermSequence minRGcB = {minRG, cB};
288         TIntermTyped *minRGB = CreateBuiltInFunctionCallNode("min", &minRGcB, *mSymbolTable, 100);
289 
290         mMinv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_minv3", minRGB, {c});
291     }
292 
293     // float ANGLE_maxv3(vec3 c)
294     // {
295     //     return max(max(c.r, c.g), c.b);
296     // }
297     {
298         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
299 
300         TIntermTyped *cR = new TIntermSwizzle(c, {0});
301         TIntermTyped *cG = new TIntermSwizzle(c->deepCopy(), {1});
302         TIntermTyped *cB = new TIntermSwizzle(c->deepCopy(), {2});
303 
304         // max(c.r, c.g)
305         TIntermSequence cRcG = {cR, cG};
306         TIntermTyped *maxRG  = CreateBuiltInFunctionCallNode("max", &cRcG, *mSymbolTable, 100);
307 
308         // max(max(c.r, c.g), c.b)
309         TIntermSequence maxRGcB = {maxRG, cB};
310         TIntermTyped *maxRGB = CreateBuiltInFunctionCallNode("max", &maxRGcB, *mSymbolTable, 100);
311 
312         mMaxv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_maxv3", maxRGB, {c});
313     }
314 
315     // float ANGLE_lumv3(vec3 c)
316     // {
317     //     return dot(c, vec3(0.30f, 0.59f, 0.11f));
318     // }
319     {
320         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
321 
322         constexpr std::array<float, 3> kCoeff = {0.30f, 0.59f, 0.11f};
323         TIntermConstantUnion *coeff           = CreateVecNode(kCoeff.data(), 3, EbpMedium);
324 
325         // dot(c, coeff)
326         TIntermSequence cCoeff = {c, coeff};
327         TIntermTyped *dot      = CreateBuiltInFunctionCallNode("dot", &cCoeff, *mSymbolTable, 100);
328 
329         mLumv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_lumv3", dot, {c});
330     }
331 
332     // float ANGLE_satv3(vec3 c)
333     // {
334     //     return ANGLE_maxv3(c) - ANGLE_minv3(c);
335     // }
336     {
337         TIntermSymbol *c = MakeVariable(mSymbolTable, "c", vec3ParamType);
338 
339         // ANGLE_maxv3(c)
340         TIntermSequence cMaxArg = {c};
341         TIntermTyped *maxv3 =
342             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
343 
344         // ANGLE_minv3(c)
345         TIntermSequence cMinArg = {c->deepCopy()};
346         TIntermTyped *minv3 =
347             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
348 
349         // max - min
350         TIntermTyped *diff = new TIntermBinary(EOpSub, maxv3, minv3);
351 
352         mSatv3 = MakeSimpleFunctionDefinition(mSymbolTable, "ANGLE_satv3", diff, {c});
353     }
354 
355     // vec3 ANGLE_clip_color(vec3 color)
356     // {
357     //     float lum = ANGLE_lumv3(color);
358     //     float mincol = ANGLE_minv3(color);
359     //     float maxcol = ANGLE_maxv3(color);
360     //     if (mincol < 0.0f)
361     //     {
362     //         color = lum + ((color - lum) * lum) / (lum - mincol);
363     //     }
364     //     if (maxcol > 1.0f)
365     //     {
366     //         color = lum + ((color - lum) * (1.0f - lum)) / (maxcol - lum);
367     //     }
368     //     return color;
369     // }
370     {
371         TIntermSymbol *color  = MakeVariable(mSymbolTable, "color", vec3ParamType);
372         TIntermSymbol *lum    = MakeVariable(mSymbolTable, "lum", floatType);
373         TIntermSymbol *mincol = MakeVariable(mSymbolTable, "mincol", floatType);
374         TIntermSymbol *maxcol = MakeVariable(mSymbolTable, "maxcol", floatType);
375 
376         // ANGLE_lumv3(color)
377         TIntermSequence cLumArg = {color};
378         TIntermTyped *lumv3 =
379             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cLumArg);
380 
381         // ANGLE_minv3(color)
382         TIntermSequence cMinArg = {color->deepCopy()};
383         TIntermTyped *minv3 =
384             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
385 
386         // ANGLE_maxv3(color)
387         TIntermSequence cMaxArg = {color->deepCopy()};
388         TIntermTyped *maxv3 =
389             TIntermAggregate::CreateFunctionCall(*mMaxv3->getFunction(), &cMaxArg);
390 
391         TIntermBlock *body = new TIntermBlock;
392         body->appendStatement(CreateTempInitDeclarationNode(&lum->variable(), lumv3));
393         body->appendStatement(CreateTempInitDeclarationNode(&mincol->variable(), minv3));
394         body->appendStatement(CreateTempInitDeclarationNode(&maxcol->variable(), maxv3));
395 
396         // color - lum
397         TIntermTyped *colorMinusLum = new TIntermBinary(EOpSub, color->deepCopy(), lum);
398         // (color - lum) * lum
399         TIntermTyped *colorMinusLumTimesLum =
400             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum, lum->deepCopy());
401         // lum - mincol
402         TIntermTyped *lumMinusMincol = new TIntermBinary(EOpSub, lum->deepCopy(), mincol);
403         // ((color - lum) * lum) / (lum - mincol)
404         TIntermTyped *negativeMincolLumOffset =
405             new TIntermBinary(EOpDiv, colorMinusLumTimesLum, lumMinusMincol);
406         // lum + ((color - lum) * lum) / (lum - mincol)
407         TIntermTyped *negativeMincolOffset =
408             new TIntermBinary(EOpAdd, lum->deepCopy(), negativeMincolLumOffset);
409         // color = lum + ((color - lum) * lum) / (lum - mincol)
410         TIntermBlock *if1Body = new TIntermBlock;
411         if1Body->appendStatement(
412             new TIntermBinary(EOpAssign, color->deepCopy(), negativeMincolOffset));
413 
414         // mincol < 0.0f
415         TIntermTyped *lessZero = new TIntermBinary(EOpLessThan, mincol->deepCopy(), Float(0));
416         // if (mincol < 0.0f) ...
417         body->appendStatement(new TIntermIfElse(lessZero, if1Body, nullptr));
418 
419         // 1.0f - lum
420         TIntermTyped *oneMinusLum = new TIntermBinary(EOpSub, Float(1.0f), lum->deepCopy());
421         // (color - lum) * (1.0f - lum)
422         TIntermTyped *colorMinusLumTimesOneMinusLum =
423             new TIntermBinary(EOpVectorTimesScalar, colorMinusLum->deepCopy(), oneMinusLum);
424         // maxcol - lum
425         TIntermTyped *maxcolMinusLum = new TIntermBinary(EOpSub, maxcol, lum->deepCopy());
426         // (color - lum) * (1.0f - lum) / (maxcol - lum)
427         TIntermTyped *largeMaxcolLumOffset =
428             new TIntermBinary(EOpDiv, colorMinusLumTimesOneMinusLum, maxcolMinusLum);
429         // lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
430         TIntermTyped *largeMaxcolOffset =
431             new TIntermBinary(EOpAdd, lum->deepCopy(), largeMaxcolLumOffset);
432         // color = lum + (color - lum) * (1.0f - lum) / (maxcol - lum)
433         TIntermBlock *if2Body = new TIntermBlock;
434         if2Body->appendStatement(
435             new TIntermBinary(EOpAssign, color->deepCopy(), largeMaxcolOffset));
436 
437         // maxcol > 1.0f
438         TIntermTyped *largerOne = new TIntermBinary(EOpGreaterThan, maxcol->deepCopy(), Float(1));
439         // if (maxcol > 1.0f) ...
440         body->appendStatement(new TIntermIfElse(largerOne, if2Body, nullptr));
441 
442         body->appendStatement(new TIntermBranch(EOpReturn, color->deepCopy()));
443 
444         const TFunction *function =
445             MakeFunction(mSymbolTable, "ANGLE_clip_color", vec3Type, {&color->variable()});
446         mClipColor = MakeFunctionDefinition(function, body);
447     }
448 
449     // vec3 ANGLE_set_lum(vec3 cbase, vec3 clum)
450     // {
451     //     float lbase = ANGLE_lumv3(cbase);
452     //     float llum = ANGLE_lumv3(clum);
453     //     float ldiff = llum - lbase;
454     //     vec3 color = cbase + ldiff;
455     //     return ANGLE_clip_color(color);
456     // }
457     {
458         TIntermSymbol *cbase = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
459         TIntermSymbol *clum  = MakeVariable(mSymbolTable, "clum", vec3ParamType);
460 
461         // ANGLE_lumv3(cbase)
462         TIntermSequence cbaseArg = {cbase};
463         TIntermTyped *lbase =
464             TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &cbaseArg);
465 
466         // ANGLE_lumv3(clum)
467         TIntermSequence clumArg = {clum};
468         TIntermTyped *llum = TIntermAggregate::CreateFunctionCall(*mLumv3->getFunction(), &clumArg);
469 
470         // llum - lbase
471         TIntermTyped *ldiff = new TIntermBinary(EOpSub, llum, lbase);
472         // cbase + ldiff
473         TIntermTyped *color = new TIntermBinary(EOpAdd, cbase->deepCopy(), ldiff);
474         // ANGLE_clip_color(color);
475         TIntermSequence clipColorArg = {color};
476         TIntermTyped *result =
477             TIntermAggregate::CreateFunctionCall(*mClipColor->getFunction(), &clipColorArg);
478 
479         TIntermBlock *body = new TIntermBlock;
480         body->appendStatement(new TIntermBranch(EOpReturn, result));
481 
482         const TFunction *function = MakeFunction(mSymbolTable, "ANGLE_set_lum", vec3Type,
483                                                  {&cbase->variable(), &clum->variable()});
484         mSetLum                   = MakeFunctionDefinition(function, body);
485     }
486 
487     // vec3 ANGLE_set_lum_sat(vec3 cbase, vec3 csat, vec3 clum)
488     // {
489     //     float minbase = ANGLE_minv3(cbase);
490     //     float sbase = ANGLE_satv3(cbase);
491     //     float ssat = ANGLE_satv3(csat);
492     //     vec3 color;
493     //     if (sbase > 0.0f)
494     //     {
495     //         color = (cbase - minbase) * ssat / sbase;
496     //     }
497     //     else
498     //     {
499     //         color = vec3(0.0f);
500     //     }
501     //     return ANGLE_set_lum(color, clum);
502     // }
503     {
504         TIntermSymbol *cbase   = MakeVariable(mSymbolTable, "cbase", vec3ParamType);
505         TIntermSymbol *csat    = MakeVariable(mSymbolTable, "csat", vec3ParamType);
506         TIntermSymbol *clum    = MakeVariable(mSymbolTable, "clum", vec3ParamType);
507         TIntermSymbol *minbase = MakeVariable(mSymbolTable, "minbase", floatType);
508         TIntermSymbol *sbase   = MakeVariable(mSymbolTable, "sbase", floatType);
509         TIntermSymbol *ssat    = MakeVariable(mSymbolTable, "ssat", floatType);
510 
511         // ANGLE_minv3(cbase)
512         TIntermSequence cMinArg = {cbase};
513         TIntermTyped *minv3 =
514             TIntermAggregate::CreateFunctionCall(*mMinv3->getFunction(), &cMinArg);
515 
516         // ANGLE_satv3(cbase)
517         TIntermSequence cSatArg = {cbase->deepCopy()};
518         TIntermTyped *baseSatv3 =
519             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &cSatArg);
520 
521         // ANGLE_satv3(csat)
522         TIntermSequence sSatArg = {csat};
523         TIntermTyped *satSatv3 =
524             TIntermAggregate::CreateFunctionCall(*mSatv3->getFunction(), &sSatArg);
525 
526         TIntermBlock *body = new TIntermBlock;
527         body->appendStatement(CreateTempInitDeclarationNode(&minbase->variable(), minv3));
528         body->appendStatement(CreateTempInitDeclarationNode(&sbase->variable(), baseSatv3));
529         body->appendStatement(CreateTempInitDeclarationNode(&ssat->variable(), satSatv3));
530 
531         // cbase - minbase
532         TIntermTyped *cbaseMinusMinbase = new TIntermBinary(EOpSub, cbase->deepCopy(), minbase);
533         // (cbase - minbase) * ssat
534         TIntermTyped *cbaseMinusMinbaseTimesSsat =
535             new TIntermBinary(EOpVectorTimesScalar, cbaseMinusMinbase, ssat);
536         // (cbase - minbase) * ssat / sbase
537         TIntermTyped *colorSbaseGreaterZero =
538             new TIntermBinary(EOpDiv, cbaseMinusMinbaseTimesSsat, sbase);
539 
540         // sbase > 0.0f
541         TIntermTyped *greaterZero = new TIntermBinary(EOpGreaterThan, sbase->deepCopy(), Float(0));
542 
543         // sbase > 0.0f ? (cbase - minbase) * ssat / sbase : vec3(0.0)
544         TIntermTyped *color =
545             new TIntermTernary(greaterZero, colorSbaseGreaterZero, CreateZeroNode(*vec3Type));
546 
547         // ANGLE_set_lum(color);
548         TIntermSequence setLumArg = {color, clum};
549         TIntermTyped *result =
550             TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &setLumArg);
551 
552         body->appendStatement(new TIntermBranch(EOpReturn, result));
553 
554         const TFunction *function =
555             MakeFunction(mSymbolTable, "ANGLE_set_lum_sat", vec3Type,
556                          {&cbase->variable(), &csat->variable(), &clum->variable()});
557         mSetLumSat = MakeFunctionDefinition(function, body);
558     }
559 }
560 
generateBlendFunctions()561 void Builder::generateBlendFunctions()
562 {
563     const TPrecision precision = mOutputVar->getType().getPrecision();
564 
565     TType *floatParamType = new TType(EbtFloat, precision, EvqParamIn, 1);
566     TType *vec3ParamType  = new TType(EbtFloat, precision, EvqParamIn, 3);
567 
568     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
569     for (gl::BlendEquationType equation : enabledBlendEquations)
570     {
571         switch (equation)
572         {
573             case gl::BlendEquationType::Multiply:
574                 // float ANGLE_blend_multiply(float src, float dst)
575                 // {
576                 //     return src * dst;
577                 // }
578                 {
579                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
580                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
581 
582                     // src * dst
583                     TIntermTyped *result = new TIntermBinary(EOpMul, src, dst);
584 
585                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
586                         mSymbolTable, "ANGLE_blend_multiply", result, {src, dst});
587                 }
588                 break;
589             case gl::BlendEquationType::Screen:
590                 // float ANGLE_blend_screen(float src, float dst)
591                 // {
592                 //     return src + dst - src * dst;
593                 // }
594                 {
595                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
596                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
597 
598                     // src + dst
599                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
600                     // src * dst
601                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
602                     // src + dst - src * dst
603                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul);
604 
605                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
606                         mSymbolTable, "ANGLE_blend_screen", result, {src, dst});
607                 }
608                 break;
609             case gl::BlendEquationType::Overlay:
610             case gl::BlendEquationType::Hardlight:
611                 // float ANGLE_blend_overlay(float src, float dst)
612                 // {
613                 //     if (dst <= 0.5f)
614                 //     {
615                 //         return (2.0f * src * dst);
616                 //     }
617                 //     else
618                 //     {
619                 //         return (1.0f - 2.0f * (1.0f - src) * (1.0f - dst));
620                 //     }
621                 //
622                 //     // Equivalently generated as:
623                 //     // return dst <= 0.5f ? 2.*src*dst : 2.*(src+dst) - 2.*src*dst - 1.;
624                 // }
625                 //
626                 // float ANGLE_blend_hardlight(float src, float dst)
627                 // {
628                 //     // Same as overlay, with the |if| checking |src| instead of |dst|.
629                 // }
630                 {
631                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
632                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
633 
634                     // src + dst
635                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
636                     // 2 * (src + dst)
637                     TIntermTyped *sum2 = new TIntermBinary(EOpMul, sum, Float(2));
638                     // src * dst
639                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
640                     // 2 * src * dst
641                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
642                     // 2 * (src + dst) - 2 * src * dst
643                     TIntermTyped *sum2MinusMul2 = new TIntermBinary(EOpSub, sum2, mul2);
644                     // 2 * (src + dst) - 2 * src * dst - 1
645                     TIntermTyped *sum2MinusMul2Minus1 =
646                         new TIntermBinary(EOpSub, sum2MinusMul2, Float(1));
647 
648                     // dst[src] <= 0.5
649                     TIntermSymbol *conditionSymbol =
650                         equation == gl::BlendEquationType::Overlay ? dst : src;
651                     TIntermTyped *lessHalf = new TIntermBinary(
652                         EOpLessThanEqual, conditionSymbol->deepCopy(), Float(0.5));
653                     // dst[src] <= 0.5f ? ...
654                     TIntermTyped *result =
655                         new TIntermTernary(lessHalf, mul2->deepCopy(), sum2MinusMul2Minus1);
656 
657                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
658                         mSymbolTable,
659                         equation == gl::BlendEquationType::Overlay ? "ANGLE_blend_overlay"
660                                                                    : "ANGLE_blend_hardlight",
661                         result, {src, dst});
662                 }
663                 break;
664             case gl::BlendEquationType::Darken:
665                 // float ANGLE_blend_darken(float src, float dst)
666                 // {
667                 //     return min(src, dst);
668                 // }
669                 {
670                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
671                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
672 
673                     // min(src, dst)
674                     TIntermSequence minArgs = {src, dst};
675                     TIntermTyped *result =
676                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
677 
678                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
679                         mSymbolTable, "ANGLE_blend_darken", result, {src, dst});
680                 }
681                 break;
682             case gl::BlendEquationType::Lighten:
683                 // float ANGLE_blend_lighten(float src, float dst)
684                 // {
685                 //     return max(src, dst);
686                 // }
687                 {
688                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
689                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
690 
691                     // max(src, dst)
692                     TIntermSequence maxArgs = {src, dst};
693                     TIntermTyped *result =
694                         CreateBuiltInFunctionCallNode("max", &maxArgs, *mSymbolTable, 100);
695 
696                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
697                         mSymbolTable, "ANGLE_blend_lighten", result, {src, dst});
698                 }
699                 break;
700             case gl::BlendEquationType::Colordodge:
701                 // float ANGLE_blend_dodge(float src, float dst)
702                 // {
703                 //     if (dst <= 0.0f)
704                 //     {
705                 //         return 0.0;
706                 //     }
707                 //     else if (src >= 1.0f)   // dst > 0.0
708                 //     {
709                 //         return 1.0;
710                 //     }
711                 //     else                    // dst > 0.0 && src < 1.0
712                 //     {
713                 //         return min(1.0, dst / (1.0 - src));
714                 //     }
715                 //
716                 //     // Equivalently generated as:
717                 //     // return dst <= 0. ? 0. : src >= 1. ? 1. : min(1., dst / (1. - src));
718                 // }
719                 {
720                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
721                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
722 
723                     // 1. - src
724                     TIntermTyped *oneMinusSrc = new TIntermBinary(EOpSub, Float(1), src);
725                     // dst / (1. - src)
726                     TIntermTyped *dstDivOneMinusSrc = new TIntermBinary(EOpDiv, dst, oneMinusSrc);
727                     // min(1., dst / (1. - src))
728                     TIntermSequence minArgs = {Float(1), dstDivOneMinusSrc};
729                     TIntermTyped *result =
730                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
731 
732                     // src >= 1
733                     TIntermTyped *greaterOne =
734                         new TIntermBinary(EOpGreaterThanEqual, src->deepCopy(), Float(1));
735                     // src >= 1. ? ...
736                     result = new TIntermTernary(greaterOne, Float(1), result);
737 
738                     // dst <= 0
739                     TIntermTyped *lessZero =
740                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0));
741                     // dst <= 0. ? ...
742                     result = new TIntermTernary(lessZero, Float(0), result);
743 
744                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
745                         mSymbolTable, "ANGLE_blend_dodge", result, {src, dst});
746                 }
747                 break;
748             case gl::BlendEquationType::Colorburn:
749                 // float ANGLE_blend_burn(float src, float dst)
750                 // {
751                 //     if (dst >= 1.0f)
752                 //     {
753                 //         return 1.0;
754                 //     }
755                 //     else if (src <= 0.0f)   // dst < 1.0
756                 //     {
757                 //         return 0.0;
758                 //     }
759                 //     else                    // dst < 1.0 && src > 0.0
760                 //     {
761                 //         return 1.0f - min(1.0f, (1.0f - dst) / src);
762                 //     }
763                 //
764                 //     // Equivalently generated as:
765                 //     // return dst >= 1. ? 1. : src <= 0. ? 0. : 1. - min(1., (1. - dst) / src);
766                 // }
767                 {
768                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
769                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
770 
771                     // 1. - dst
772                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
773                     // (1. - dst) / src
774                     TIntermTyped *oneMinusDstDivSrc = new TIntermBinary(EOpDiv, oneMinusDst, src);
775                     // min(1., (1. - dst) / src)
776                     TIntermSequence minArgs = {Float(1), oneMinusDstDivSrc};
777                     TIntermTyped *result =
778                         CreateBuiltInFunctionCallNode("min", &minArgs, *mSymbolTable, 100);
779                     // 1. - min(1., (1. - dst) / src)
780                     result = new TIntermBinary(EOpSub, Float(1), result);
781 
782                     // src <= 0
783                     TIntermTyped *lessZero =
784                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0));
785                     // src <= 0. ? ...
786                     result = new TIntermTernary(lessZero, Float(0), result);
787 
788                     // dst >= 1
789                     TIntermTyped *greaterOne =
790                         new TIntermBinary(EOpGreaterThanEqual, dst->deepCopy(), Float(1));
791                     // dst >= 1. ? ...
792                     result = new TIntermTernary(greaterOne, Float(1), result);
793 
794                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
795                         mSymbolTable, "ANGLE_blend_burn", result, {src, dst});
796                 }
797                 break;
798             case gl::BlendEquationType::Softlight:
799                 // float ANGLE_blend_softlight(float src, float dst)
800                 // {
801                 //     if (src <= 0.5f)
802                 //     {
803                 //         return (dst - (1.0f - 2.0f * src) * dst * (1.0f - dst));
804                 //     }
805                 //     else if (dst <= 0.25f)  // src > 0.5
806                 //     {
807                 //         return (dst + (2.0f * src - 1.0f) * dst * ((16.0f * dst - 12.0f) * dst
808                 //         + 3.0f));
809                 //     }
810                 //     else                    // src > 0.5 && dst > 0.25
811                 //     {
812                 //         return (dst + (2.0f * src - 1.0f) * (sqrt(dst) - dst));
813                 //     }
814                 //
815                 //     // Equivalently generated as:
816                 //     // return dst + (2. * src - 1.) * (
817                 //     //            src <= 0.5  ? dst * (1. - dst) :
818                 //     //            dst <= 0.25 ? dst * ((16. * dst - 12.) * dst + 3.) :
819                 //     //                          sqrt(dst) - dst)
820                 // }
821                 {
822                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
823                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
824 
825                     // 2. * src
826                     TIntermTyped *src2 = new TIntermBinary(EOpMul, Float(2), src);
827                     // 2. * src - 1.
828                     TIntermTyped *src2Minus1 = new TIntermBinary(EOpSub, src2, Float(1));
829                     // 1. - dst
830                     TIntermTyped *oneMinusDst = new TIntermBinary(EOpSub, Float(1), dst);
831                     // dst * (1. - dst)
832                     TIntermTyped *dstTimesOneMinusDst =
833                         new TIntermBinary(EOpMul, dst->deepCopy(), oneMinusDst);
834                     // 16. * dst
835                     TIntermTyped *dst16 = new TIntermBinary(EOpMul, Float(16), dst->deepCopy());
836                     // 16. * dst - 12.
837                     TIntermTyped *dst16Minus12 = new TIntermBinary(EOpSub, dst16, Float(12));
838                     // (16. * dst - 12.) * dst
839                     TIntermTyped *dst16Minus12TimesDst =
840                         new TIntermBinary(EOpMul, dst16Minus12, dst->deepCopy());
841                     // (16. * dst - 12.) * dst + 3.
842                     TIntermTyped *dst16Minus12TimesDstPlus3 =
843                         new TIntermBinary(EOpAdd, dst16Minus12TimesDst, Float(3));
844                     // dst * ((16. * dst - 12.) * dst + 3.)
845                     TIntermTyped *dstTimesDst16Minus12TimesDstPlus3 =
846                         new TIntermBinary(EOpMul, dst->deepCopy(), dst16Minus12TimesDstPlus3);
847                     // sqrt(dst)
848                     TIntermSequence sqrtArg = {dst->deepCopy()};
849                     TIntermTyped *sqrtDst =
850                         CreateBuiltInFunctionCallNode("sqrt", &sqrtArg, *mSymbolTable, 100);
851                     // sqrt(dst) - dst
852                     TIntermTyped *sqrtDstMinusDst =
853                         new TIntermBinary(EOpSub, sqrtDst, dst->deepCopy());
854 
855                     // dst <= 0.25
856                     TIntermTyped *lessQuarter =
857                         new TIntermBinary(EOpLessThanEqual, dst->deepCopy(), Float(0.25));
858                     // dst <= 0.25 ? ...
859                     TIntermTyped *result = new TIntermTernary(
860                         lessQuarter, dstTimesDst16Minus12TimesDstPlus3, sqrtDstMinusDst);
861 
862                     // src <= 0.5
863                     TIntermTyped *lessHalf =
864                         new TIntermBinary(EOpLessThanEqual, src->deepCopy(), Float(0.5));
865                     // src <= 0.5 ? ...
866                     result = new TIntermTernary(lessHalf, dstTimesOneMinusDst, result);
867 
868                     // (2. * src - 1.) * ...
869                     result = new TIntermBinary(EOpMul, src2Minus1, result);
870                     // dst + (2. * src - 1.) * ...
871                     result = new TIntermBinary(EOpAdd, dst->deepCopy(), result);
872 
873                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
874                         mSymbolTable, "ANGLE_blend_softlight", result, {src, dst});
875                 }
876                 break;
877             case gl::BlendEquationType::Difference:
878                 // float ANGLE_blend_difference(float src, float dst)
879                 // {
880                 //     return abs(dst - src);
881                 // }
882                 {
883                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
884                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
885 
886                     // dst - src
887                     TIntermTyped *dstMinusSrc = new TIntermBinary(EOpSub, dst, src);
888                     // abs(dst - src)
889                     TIntermSequence absArgs = {dstMinusSrc};
890                     TIntermTyped *result =
891                         CreateBuiltInFunctionCallNode("abs", &absArgs, *mSymbolTable, 100);
892 
893                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
894                         mSymbolTable, "ANGLE_blend_difference", result, {src, dst});
895                 }
896                 break;
897             case gl::BlendEquationType::Exclusion:
898                 // float ANGLE_blend_exclusion(float src, float dst)
899                 // {
900                 //     return src + dst - (2.0f * src * dst);
901                 // }
902                 {
903                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", floatParamType);
904                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", floatParamType);
905 
906                     // src + dst
907                     TIntermTyped *sum = new TIntermBinary(EOpAdd, src, dst);
908                     // src * dst
909                     TIntermTyped *mul = new TIntermBinary(EOpMul, src->deepCopy(), dst->deepCopy());
910                     // 2 * src * dst
911                     TIntermTyped *mul2 = new TIntermBinary(EOpMul, mul, Float(2));
912                     // src + dst - 2 * src * dst
913                     TIntermTyped *result = new TIntermBinary(EOpSub, sum, mul2);
914 
915                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
916                         mSymbolTable, "ANGLE_blend_exclusion", result, {src, dst});
917                 }
918                 break;
919             case gl::BlendEquationType::HslHue:
920                 // vec3 ANGLE_blend_hsl_hue(vec3 src, vec3 dst)
921                 // {
922                 //     return ANGLE_set_lum_sat(src, dst, dst);
923                 // }
924                 {
925                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
926                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
927 
928                     TIntermSequence args = {src, dst, dst->deepCopy()};
929                     TIntermTyped *result =
930                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
931 
932                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
933                         mSymbolTable, "ANGLE_blend_hsl_hue", result, {src, dst});
934                 }
935                 break;
936             case gl::BlendEquationType::HslSaturation:
937                 // vec3 ANGLE_blend_hsl_saturation(vec3 src, vec3 dst)
938                 // {
939                 //     return ANGLE_set_lum_sat(dst, src, dst);
940                 // }
941                 {
942                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
943                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
944 
945                     TIntermSequence args = {dst, src, dst->deepCopy()};
946                     TIntermTyped *result =
947                         TIntermAggregate::CreateFunctionCall(*mSetLumSat->getFunction(), &args);
948 
949                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
950                         mSymbolTable, "ANGLE_blend_hsl_saturation", result, {src, dst});
951                 }
952                 break;
953             case gl::BlendEquationType::HslColor:
954                 // vec3 ANGLE_blend_hsl_color(vec3 src, vec3 dst)
955                 // {
956                 //     return ANGLE_set_lum(src, dst);
957                 // }
958                 {
959                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
960                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
961 
962                     TIntermSequence args = {src, dst};
963                     TIntermTyped *result =
964                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
965 
966                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
967                         mSymbolTable, "ANGLE_blend_hsl_color", result, {src, dst});
968                 }
969                 break;
970             case gl::BlendEquationType::HslLuminosity:
971                 // vec3 ANGLE_blend_hsl_luminosity(vec3 src, vec3 dst)
972                 // {
973                 //     return ANGLE_set_lum(dst, src);
974                 // }
975                 {
976                     TIntermSymbol *src = MakeVariable(mSymbolTable, "src", vec3ParamType);
977                     TIntermSymbol *dst = MakeVariable(mSymbolTable, "dst", vec3ParamType);
978 
979                     TIntermSequence args = {dst, src};
980                     TIntermTyped *result =
981                         TIntermAggregate::CreateFunctionCall(*mSetLum->getFunction(), &args);
982 
983                     mBlendFuncs[equation] = MakeSimpleFunctionDefinition(
984                         mSymbolTable, "ANGLE_blend_hsl_luminosity", result, {src, dst});
985                 }
986                 break;
987             default:
988                 // Only advanced blend equations are possible.
989                 UNREACHABLE();
990         }
991     }
992 }
993 
insertGeneratedFunctions(TIntermBlock * root)994 void Builder::insertGeneratedFunctions(TIntermBlock *root)
995 {
996     // Insert all generated functions in root.  Since they are all inserted at index 0, HSL helpers
997     // are inserted last, and in opposite order.
998     for (TIntermFunctionDefinition *blendFunc : mBlendFuncs)
999     {
1000         if (blendFunc != nullptr)
1001         {
1002             root->insertStatement(0, blendFunc);
1003         }
1004     }
1005     if (mMinv3 != nullptr)
1006     {
1007         root->insertStatement(0, mSetLumSat);
1008         root->insertStatement(0, mSetLum);
1009         root->insertStatement(0, mClipColor);
1010         root->insertStatement(0, mSatv3);
1011         root->insertStatement(0, mLumv3);
1012         root->insertStatement(0, mMaxv3);
1013         root->insertStatement(0, mMinv3);
1014     }
1015 }
1016 
1017 // On some platforms 1.0f is not returned even when the dividend and divisor have the same value.
1018 // In such cases emit 1.0f when the dividend and divisor are equal, else return the divide node
divideFloatNode(TIntermTyped * dividend,TIntermTyped * divisor)1019 TIntermTyped *Builder::divideFloatNode(TIntermTyped *dividend, TIntermTyped *divisor)
1020 {
1021     TIntermBinary *cond = new TIntermBinary(EOpEqual, dividend->deepCopy(), divisor->deepCopy());
1022     TIntermBinary *divideExpr =
1023         new TIntermBinary(EOpDiv, dividend->deepCopy(), divisor->deepCopy());
1024     return new TIntermTernary(cond, CreateFloatNode(1.0f, EbpHigh), divideExpr->deepCopy());
1025 }
1026 
premultiplyAlpha(TIntermBlock * blendBlock,TIntermTyped * var,const char * name)1027 TIntermSymbol *Builder::premultiplyAlpha(TIntermBlock *blendBlock,
1028                                          TIntermTyped *var,
1029                                          const char *name)
1030 {
1031     const TPrecision precision = mOutputVar->getType().getPrecision();
1032     TType *vec3Type            = new TType(EbtFloat, precision, EvqTemporary, 3);
1033 
1034     // symbol = vec3(0)
1035     // If alpha != 0, divide by alpha.  Note that due to precision issues, component == alpha is
1036     // handled especially.  This precision issue affects multiple vendors, and most drivers seem to
1037     // be carrying a similar workaround to pass the CTS test.
1038     TIntermTyped *alpha            = new TIntermSwizzle(var, {3});
1039     TIntermSymbol *symbol          = MakeVariable(mSymbolTable, name, vec3Type);
1040     TIntermTyped *alphaNotZero     = new TIntermBinary(EOpNotEqual, alpha, Float(0));
1041     TIntermBlock *rgbDivAlphaBlock = new TIntermBlock;
1042 
1043     constexpr int kColorChannels = 3;
1044     // For each component:
1045     // symbol.x = (var.x == var.w) ? 1.0 : var.x / var.w
1046     for (int index = 0; index < kColorChannels; index++)
1047     {
1048         TIntermTyped *divideNode        = divideFloatNode(new TIntermSwizzle(var, {index}), alpha);
1049         TIntermBinary *assignDivideNode = new TIntermBinary(
1050             EOpAssign, new TIntermSwizzle(symbol->deepCopy(), {index}), divideNode);
1051         rgbDivAlphaBlock->appendStatement(assignDivideNode);
1052     }
1053 
1054     TIntermIfElse *ifBlock = new TIntermIfElse(alphaNotZero, rgbDivAlphaBlock, nullptr);
1055     blendBlock->appendStatement(
1056         CreateTempInitDeclarationNode(&symbol->variable(), CreateZeroNode(*vec3Type)));
1057     blendBlock->appendStatement(ifBlock);
1058 
1059     return symbol;
1060 }
1061 
GetFirstElementIfArray(TIntermTyped * var)1062 TIntermTyped *GetFirstElementIfArray(TIntermTyped *var)
1063 {
1064     TIntermTyped *element = var;
1065     while (element->getType().isArray())
1066     {
1067         element = new TIntermBinary(EOpIndexDirect, element, CreateIndexNode(0));
1068     }
1069     return element;
1070 }
1071 
generatePreamble(TIntermBlock * blendBlock)1072 void Builder::generatePreamble(TIntermBlock *blendBlock)
1073 {
1074     // Use subpassLoad to read from the input attachment
1075     const TPrecision precision      = mOutputVar->getType().getPrecision();
1076     TType *vec4Type                 = new TType(EbtFloat, precision, EvqTemporary, 4);
1077     TIntermSymbol *subpassInputData = MakeVariable(mSymbolTable, "ANGLELastFragData", vec4Type);
1078 
1079     // Initialize it with subpassLoad() result.
1080     TIntermSequence subpassArguments  = {new TIntermSymbol(mSubpassInputVar)};
1081     TIntermTyped *subpassLoadFuncCall = CreateBuiltInFunctionCallNode(
1082         "subpassLoad", &subpassArguments, *mSymbolTable, kESSLInternalBackendBuiltIns);
1083 
1084     blendBlock->appendStatement(
1085         CreateTempInitDeclarationNode(&subpassInputData->variable(), subpassLoadFuncCall));
1086 
1087     // Get element 0 of the output, if array.
1088     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1089 
1090     // Expand output to vec4, if not already.
1091     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1092     if (vecSize < 4)
1093     {
1094         TIntermSequence vec4Args = {output};
1095         for (uint32_t channel = vecSize; channel < 3; ++channel)
1096         {
1097             vec4Args.push_back(Float(0));
1098         }
1099         vec4Args.push_back(Float(1));
1100         output = TIntermAggregate::CreateConstructor(*vec4Type, &vec4Args);
1101     }
1102 
1103     // Premultiply src and dst.
1104     mSrc = premultiplyAlpha(blendBlock, output, "ANGLE_blend_src");
1105     mDst = premultiplyAlpha(blendBlock, subpassInputData, "ANGLE_blend_dst");
1106 
1107     // Calculate the p coefficients:
1108     TIntermTyped *srcAlpha = new TIntermSwizzle(output->deepCopy(), {3});
1109     TIntermTyped *dstAlpha = new TIntermSwizzle(subpassInputData->deepCopy(), {3});
1110 
1111     // As * Ad
1112     TIntermTyped *AsTimesAd = new TIntermBinary(EOpMul, srcAlpha, dstAlpha);
1113     // As * (1. - Ad)
1114     TIntermTyped *oneMinusAd        = new TIntermBinary(EOpSub, Float(1), dstAlpha->deepCopy());
1115     TIntermTyped *AsTimesOneMinusAd = new TIntermBinary(EOpMul, srcAlpha->deepCopy(), oneMinusAd);
1116     // Ad * (1. - As)
1117     TIntermTyped *oneMinusAs        = new TIntermBinary(EOpSub, Float(1), srcAlpha->deepCopy());
1118     TIntermTyped *AdTimesOneMinusAs = new TIntermBinary(EOpMul, dstAlpha->deepCopy(), oneMinusAs);
1119 
1120     mP0 = MakeVariable(mSymbolTable, "ANGLE_blend_p0", &srcAlpha->getType());
1121     mP1 = MakeVariable(mSymbolTable, "ANGLE_blend_p1", &srcAlpha->getType());
1122     mP2 = MakeVariable(mSymbolTable, "ANGLE_blend_p2", &srcAlpha->getType());
1123 
1124     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP0->variable(), AsTimesAd));
1125     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP1->variable(), AsTimesOneMinusAd));
1126     blendBlock->appendStatement(CreateTempInitDeclarationNode(&mP2->variable(), AdTimesOneMinusAs));
1127 }
1128 
generateEquationSwitch(TIntermBlock * blendBlock)1129 void Builder::generateEquationSwitch(TIntermBlock *blendBlock)
1130 {
1131     const TPrecision precision = mOutputVar->getType().getPrecision();
1132 
1133     TType *vec3Type = new TType(EbtFloat, precision, EvqTemporary, 3);
1134     TType *vec4Type = new TType(EbtFloat, precision, EvqTemporary, 4);
1135 
1136     // The following code is generated:
1137     //
1138     // vec3 f;
1139     // swtich (equation)
1140     // {
1141     //    case A:
1142     //       f = ANGLE_blend_a(..);
1143     //       break;
1144     //    case B:
1145     //       f = ANGLE_blend_b(..);
1146     //       break;
1147     //    ...
1148     // }
1149     //
1150     // vec3 rgb = f * p0 + src * p1 + dst * p2
1151     // float a = p0 + p1 + p2
1152     //
1153     // output = vec4(rgb, a);
1154 
1155     TIntermSymbol *f = MakeVariable(mSymbolTable, "ANGLE_f", vec3Type);
1156     blendBlock->appendStatement(CreateTempDeclarationNode(&f->variable()));
1157 
1158     TIntermBlock *switchBody = new TIntermBlock;
1159 
1160     gl::BlendEquationBitSet enabledBlendEquations(mAdvancedBlendEquations.bits());
1161     for (gl::BlendEquationType equation : enabledBlendEquations)
1162     {
1163         switchBody->appendStatement(
1164             new TIntermCase(CreateUIntNode(static_cast<uint32_t>(equation))));
1165 
1166         // HSL equations call the blend function with all channels.  Non-HSL equations call it per
1167         // component.
1168         if (equation < gl::BlendEquationType::HslHue)
1169         {
1170             TIntermSequence constructorArgs;
1171             for (int channel = 0; channel < 3; ++channel)
1172             {
1173                 TIntermTyped *srcChannel = new TIntermSwizzle(mSrc->deepCopy(), {channel});
1174                 TIntermTyped *dstChannel = new TIntermSwizzle(mDst->deepCopy(), {channel});
1175 
1176                 TIntermSequence args = {srcChannel, dstChannel};
1177                 constructorArgs.push_back(TIntermAggregate::CreateFunctionCall(
1178                     *mBlendFuncs[equation]->getFunction(), &args));
1179             }
1180 
1181             TIntermTyped *constructor =
1182                 TIntermAggregate::CreateConstructor(*vec3Type, &constructorArgs);
1183             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), constructor));
1184         }
1185         else
1186         {
1187             TIntermSequence args = {mSrc->deepCopy(), mDst->deepCopy()};
1188             TIntermTyped *blendCall =
1189                 TIntermAggregate::CreateFunctionCall(*mBlendFuncs[equation]->getFunction(), &args);
1190 
1191             switchBody->appendStatement(new TIntermBinary(EOpAssign, f->deepCopy(), blendCall));
1192         }
1193 
1194         switchBody->appendStatement(new TIntermBranch(EOpBreak, nullptr));
1195     }
1196 
1197     // A driver uniform is used to communicate the blend equation to use.
1198     TIntermTyped *equationUniform = mDriverUniforms->getAdvancedBlendEquation();
1199 
1200     blendBlock->appendStatement(new TIntermSwitch(equationUniform, switchBody));
1201 
1202     // Calculate the final blend according to the following formula:
1203     //
1204     //     RGB = f(src, dst) * p0 + src * p1 + dst * p2
1205     //       A = p0 + p1 + p2
1206 
1207     // f * p0
1208     TIntermTyped *fTimesP0 = new TIntermBinary(EOpVectorTimesScalar, f, mP0);
1209     // src * p1
1210     TIntermTyped *srcTimesP1 = new TIntermBinary(EOpVectorTimesScalar, mSrc, mP1);
1211     // dst * p2
1212     TIntermTyped *dstTimesP2 = new TIntermBinary(EOpVectorTimesScalar, mDst, mP2);
1213     // f * p0 + src * p1 + dst * p2
1214     TIntermTyped *rgb =
1215         new TIntermBinary(EOpAdd, new TIntermBinary(EOpAdd, fTimesP0, srcTimesP1), dstTimesP2);
1216 
1217     // p0 + p1 + p2
1218     TIntermTyped *a = new TIntermBinary(
1219         EOpAdd, new TIntermBinary(EOpAdd, mP0->deepCopy(), mP1->deepCopy()), mP2->deepCopy());
1220 
1221     // Intialize the output with vec4(RGB, A)
1222     TIntermSequence rgbaArgs  = {rgb, a};
1223     TIntermTyped *blendResult = TIntermAggregate::CreateConstructor(*vec4Type, &rgbaArgs);
1224 
1225     // If the output has fewer than four channels, swizzle the results
1226     uint32_t vecSize = mOutputVar->getType().getNominalSize();
1227     if (vecSize < 4)
1228     {
1229         TVector<int> swizzle = {0, 1, 2, 3};
1230         swizzle.resize(vecSize);
1231         blendResult = new TIntermSwizzle(blendResult, swizzle);
1232     }
1233 
1234     TIntermTyped *output = GetFirstElementIfArray(new TIntermSymbol(mOutputVar));
1235 
1236     blendBlock->appendStatement(new TIntermBinary(EOpAssign, output, blendResult));
1237 }
1238 }  // anonymous namespace
1239 
EmulateAdvancedBlendEquations(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const AdvancedBlendEquations & advancedBlendEquations,const DriverUniform * driverUniforms,InputAttachmentMap * inputAttachmentMapOut)1240 bool EmulateAdvancedBlendEquations(TCompiler *compiler,
1241                                    TIntermBlock *root,
1242                                    TSymbolTable *symbolTable,
1243                                    const AdvancedBlendEquations &advancedBlendEquations,
1244                                    const DriverUniform *driverUniforms,
1245                                    InputAttachmentMap *inputAttachmentMapOut)
1246 {
1247     Builder builder(compiler, symbolTable, advancedBlendEquations, driverUniforms,
1248                     inputAttachmentMapOut);
1249     return builder.build(root);
1250 }  // namespace
1251 
1252 }  // namespace sh
1253