xref: /aosp_15_r20/external/angle/src/compiler/translator/hlsl/OutputHLSL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2002 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/hlsl/OutputHLSL.h"
8 
9 #include <stdio.h>
10 #include <algorithm>
11 #include <cfloat>
12 
13 #include "common/angleutils.h"
14 #include "common/debug.h"
15 #include "common/utilities.h"
16 #include "compiler/translator/BuiltInFunctionEmulator.h"
17 #include "compiler/translator/InfoSink.h"
18 #include "compiler/translator/StaticType.h"
19 #include "compiler/translator/blocklayout.h"
20 #include "compiler/translator/hlsl/AtomicCounterFunctionHLSL.h"
21 #include "compiler/translator/hlsl/BuiltInFunctionEmulatorHLSL.h"
22 #include "compiler/translator/hlsl/ImageFunctionHLSL.h"
23 #include "compiler/translator/hlsl/ResourcesHLSL.h"
24 #include "compiler/translator/hlsl/StructureHLSL.h"
25 #include "compiler/translator/hlsl/TextureFunctionHLSL.h"
26 #include "compiler/translator/hlsl/TranslatorHLSL.h"
27 #include "compiler/translator/hlsl/UtilsHLSL.h"
28 #include "compiler/translator/tree_ops/hlsl/RemoveSwitchFallThrough.h"
29 #include "compiler/translator/tree_util/FindSymbolNode.h"
30 #include "compiler/translator/tree_util/NodeSearch.h"
31 #include "compiler/translator/util.h"
32 
33 namespace sh
34 {
35 
36 namespace
37 {
38 
39 constexpr const char kImage2DFunctionString[] = "// @@ IMAGE2D DECLARATION FUNCTION STRING @@";
40 
ArrayHelperFunctionName(const char * prefix,const TType & type)41 TString ArrayHelperFunctionName(const char *prefix, const TType &type)
42 {
43     TStringStream fnName = sh::InitializeStream<TStringStream>();
44     fnName << prefix << "_";
45     if (type.isArray())
46     {
47         for (unsigned int arraySize : type.getArraySizes())
48         {
49             fnName << arraySize << "_";
50         }
51     }
52     fnName << TypeString(type);
53     return fnName.str();
54 }
55 
IsDeclarationWrittenOut(TIntermDeclaration * node)56 bool IsDeclarationWrittenOut(TIntermDeclaration *node)
57 {
58     TIntermSequence *sequence = node->getSequence();
59     TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
60     ASSERT(sequence->size() == 1);
61     ASSERT(variable);
62     return (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal ||
63             variable->getQualifier() == EvqConst || variable->getQualifier() == EvqShared);
64 }
65 
IsInStd140UniformBlock(TIntermTyped * node)66 bool IsInStd140UniformBlock(TIntermTyped *node)
67 {
68     TIntermBinary *binaryNode = node->getAsBinaryNode();
69 
70     if (binaryNode)
71     {
72         return IsInStd140UniformBlock(binaryNode->getLeft());
73     }
74 
75     const TType &type = node->getType();
76 
77     if (type.getQualifier() == EvqUniform)
78     {
79         // determine if we are in the standard layout
80         const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
81         if (interfaceBlock)
82         {
83             return (interfaceBlock->blockStorage() == EbsStd140);
84         }
85     }
86 
87     return false;
88 }
89 
GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped * node)90 const TInterfaceBlock *GetInterfaceBlockOfUniformBlockNearestIndexOperator(TIntermTyped *node)
91 {
92     const TIntermBinary *binaryNode = node->getAsBinaryNode();
93     if (binaryNode)
94     {
95         if (binaryNode->getOp() == EOpIndexDirectInterfaceBlock)
96         {
97             return binaryNode->getLeft()->getType().getInterfaceBlock();
98         }
99     }
100 
101     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
102     if (symbolNode)
103     {
104         const TVariable &variable = symbolNode->variable();
105         const TType &variableType = variable.getType();
106 
107         if (variableType.getQualifier() == EvqUniform &&
108             variable.symbolType() == SymbolType::UserDefined)
109         {
110             return variableType.getInterfaceBlock();
111         }
112     }
113 
114     return nullptr;
115 }
116 
GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)117 const char *GetHLSLAtomicFunctionStringAndLeftParenthesis(TOperator op)
118 {
119     switch (op)
120     {
121         case EOpAtomicAdd:
122             return "InterlockedAdd(";
123         case EOpAtomicMin:
124             return "InterlockedMin(";
125         case EOpAtomicMax:
126             return "InterlockedMax(";
127         case EOpAtomicAnd:
128             return "InterlockedAnd(";
129         case EOpAtomicOr:
130             return "InterlockedOr(";
131         case EOpAtomicXor:
132             return "InterlockedXor(";
133         case EOpAtomicExchange:
134             return "InterlockedExchange(";
135         case EOpAtomicCompSwap:
136             return "InterlockedCompareExchange(";
137         default:
138             UNREACHABLE();
139             return "";
140     }
141 }
142 
IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary & node)143 bool IsAtomicFunctionForSharedVariableDirectAssign(const TIntermBinary &node)
144 {
145     TIntermAggregate *aggregateNode = node.getRight()->getAsAggregate();
146     if (aggregateNode == nullptr)
147     {
148         return false;
149     }
150 
151     if (node.getOp() == EOpAssign && BuiltInGroup::IsAtomicMemory(aggregateNode->getOp()))
152     {
153         return !IsInShaderStorageBlock((*aggregateNode->getSequence())[0]->getAsTyped()) &&
154                !IsInShaderStorageBlock(node.getLeft());
155     }
156 
157     return false;
158 }
159 
160 const char *kZeros       = "_ANGLE_ZEROS_";
161 constexpr int kZeroCount = 256;
DefineZeroArray()162 std::string DefineZeroArray()
163 {
164     std::stringstream ss = sh::InitializeStream<std::stringstream>();
165     // For 'static', if the declaration does not include an initializer, the value is set to zero.
166     // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-variable-syntax
167     ss << "static uint " << kZeros << "[" << kZeroCount << "];\n";
168     return ss.str();
169 }
170 
GetZeroInitializer(size_t size)171 std::string GetZeroInitializer(size_t size)
172 {
173     std::stringstream ss = sh::InitializeStream<std::stringstream>();
174     size_t quotient      = size / kZeroCount;
175     size_t reminder      = size % kZeroCount;
176 
177     for (size_t i = 0; i < quotient; ++i)
178     {
179         if (i != 0)
180         {
181             ss << ", ";
182         }
183         ss << kZeros;
184     }
185 
186     for (size_t i = 0; i < reminder; ++i)
187     {
188         if (quotient != 0 || i != 0)
189         {
190             ss << ", ";
191         }
192         ss << "0";
193     }
194 
195     return ss.str();
196 }
197 
IsFlatInterpolant(TIntermTyped * node)198 bool IsFlatInterpolant(TIntermTyped *node)
199 {
200     TIntermTyped *interpolant = node->getAsBinaryNode() ? node->getAsBinaryNode()->getLeft() : node;
201     return interpolant->getType().getQualifier() == EvqFlatIn;
202 }
203 
204 }  // anonymous namespace
205 
TReferencedBlock(const TInterfaceBlock * aBlock,const TVariable * aInstanceVariable)206 TReferencedBlock::TReferencedBlock(const TInterfaceBlock *aBlock,
207                                    const TVariable *aInstanceVariable)
208     : block(aBlock), instanceVariable(aInstanceVariable)
209 {}
210 
needStructMapping(TIntermTyped * node)211 bool OutputHLSL::needStructMapping(TIntermTyped *node)
212 {
213     ASSERT(node->getBasicType() == EbtStruct);
214     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
215     {
216         TIntermNode *ancestor               = getAncestorNode(n);
217         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
218         if (ancestorBinary)
219         {
220             switch (ancestorBinary->getOp())
221             {
222                 case EOpIndexDirectStruct:
223                 {
224                     const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
225                     const TIntermConstantUnion *index =
226                         ancestorBinary->getRight()->getAsConstantUnion();
227                     const TField *field = structure->fields()[index->getIConst(0)];
228                     if (field->type()->getStruct() == nullptr)
229                     {
230                         return false;
231                     }
232                     break;
233                 }
234                 case EOpIndexDirect:
235                 case EOpIndexIndirect:
236                     break;
237                 default:
238                     return true;
239             }
240         }
241         else
242         {
243             const TIntermAggregate *ancestorAggregate = ancestor->getAsAggregate();
244             if (ancestorAggregate)
245             {
246                 return true;
247             }
248             return false;
249         }
250     }
251     return true;
252 }
253 
writeFloat(TInfoSinkBase & out,float f)254 void OutputHLSL::writeFloat(TInfoSinkBase &out, float f)
255 {
256     // This is known not to work for NaN on all drivers but make the best effort to output NaNs
257     // regardless.
258     if ((gl::isInf(f) || gl::isNaN(f)) && mShaderVersion >= 300 &&
259         mOutputType == SH_HLSL_4_1_OUTPUT)
260     {
261         out << "asfloat(" << gl::bitCast<uint32_t>(f) << "u)";
262     }
263     else
264     {
265         out << std::min(FLT_MAX, std::max(-FLT_MAX, f));
266     }
267 }
268 
writeSingleConstant(TInfoSinkBase & out,const TConstantUnion * const constUnion)269 void OutputHLSL::writeSingleConstant(TInfoSinkBase &out, const TConstantUnion *const constUnion)
270 {
271     ASSERT(constUnion != nullptr);
272     switch (constUnion->getType())
273     {
274         case EbtFloat:
275             writeFloat(out, constUnion->getFConst());
276             break;
277         case EbtInt:
278             out << constUnion->getIConst();
279             break;
280         case EbtUInt:
281             out << constUnion->getUConst();
282             break;
283         case EbtBool:
284             out << constUnion->getBConst();
285             break;
286         default:
287             UNREACHABLE();
288     }
289 }
290 
writeConstantUnionArray(TInfoSinkBase & out,const TConstantUnion * const constUnion,const size_t size)291 const TConstantUnion *OutputHLSL::writeConstantUnionArray(TInfoSinkBase &out,
292                                                           const TConstantUnion *const constUnion,
293                                                           const size_t size)
294 {
295     const TConstantUnion *constUnionIterated = constUnion;
296     for (size_t i = 0; i < size; i++, constUnionIterated++)
297     {
298         writeSingleConstant(out, constUnionIterated);
299 
300         if (i != size - 1)
301         {
302             out << ", ";
303         }
304     }
305     return constUnionIterated;
306 }
307 
OutputHLSL(sh::GLenum shaderType,ShShaderSpec shaderSpec,int shaderVersion,const TExtensionBehavior & extensionBehavior,const char * sourcePath,ShShaderOutput outputType,int numRenderTargets,int maxDualSourceDrawBuffers,const std::vector<ShaderVariable> & uniforms,const ShCompileOptions & compileOptions,sh::WorkGroupSize workGroupSize,TSymbolTable * symbolTable,PerformanceDiagnostics * perfDiagnostics,const std::map<int,const TInterfaceBlock * > & uniformBlockOptimizedMap,const std::vector<InterfaceBlock> & shaderStorageBlocks,uint8_t clipDistanceSize,uint8_t cullDistanceSize,bool isEarlyFragmentTestsSpecified)308 OutputHLSL::OutputHLSL(sh::GLenum shaderType,
309                        ShShaderSpec shaderSpec,
310                        int shaderVersion,
311                        const TExtensionBehavior &extensionBehavior,
312                        const char *sourcePath,
313                        ShShaderOutput outputType,
314                        int numRenderTargets,
315                        int maxDualSourceDrawBuffers,
316                        const std::vector<ShaderVariable> &uniforms,
317                        const ShCompileOptions &compileOptions,
318                        sh::WorkGroupSize workGroupSize,
319                        TSymbolTable *symbolTable,
320                        PerformanceDiagnostics *perfDiagnostics,
321                        const std::map<int, const TInterfaceBlock *> &uniformBlockOptimizedMap,
322                        const std::vector<InterfaceBlock> &shaderStorageBlocks,
323                        uint8_t clipDistanceSize,
324                        uint8_t cullDistanceSize,
325                        bool isEarlyFragmentTestsSpecified)
326     : TIntermTraverser(true, true, true, symbolTable),
327       mShaderType(shaderType),
328       mShaderSpec(shaderSpec),
329       mShaderVersion(shaderVersion),
330       mExtensionBehavior(extensionBehavior),
331       mSourcePath(sourcePath),
332       mOutputType(outputType),
333       mCompileOptions(compileOptions),
334       mInsideFunction(false),
335       mInsideMain(false),
336       mUniformBlockOptimizedMap(uniformBlockOptimizedMap),
337       mNumRenderTargets(numRenderTargets),
338       mMaxDualSourceDrawBuffers(maxDualSourceDrawBuffers),
339       mCurrentFunctionMetadata(nullptr),
340       mWorkGroupSize(workGroupSize),
341       mPerfDiagnostics(perfDiagnostics),
342       mClipDistanceSize(clipDistanceSize),
343       mCullDistanceSize(cullDistanceSize),
344       mIsEarlyFragmentTestsSpecified(isEarlyFragmentTestsSpecified),
345       mNeedStructMapping(false)
346 {
347     mUsesFragColor        = false;
348     mUsesFragData         = false;
349     mUsesDepthRange       = false;
350     mUsesFragCoord        = false;
351     mUsesPointCoord       = false;
352     mUsesFrontFacing      = false;
353     mUsesHelperInvocation = false;
354     mUsesPointSize        = false;
355     mUsesInstanceID       = false;
356     mHasMultiviewExtensionEnabled =
357         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview) ||
358         IsExtensionEnabled(mExtensionBehavior, TExtension::OVR_multiview2);
359     mUsesViewID                  = false;
360     mUsesVertexID                = false;
361     mUsesFragDepth               = false;
362     mUsesSampleID                = false;
363     mUsesSamplePosition          = false;
364     mUsesSampleMaskIn            = false;
365     mUsesSampleMask              = false;
366     mUsesNumSamples              = false;
367     mUsesNumWorkGroups           = false;
368     mUsesWorkGroupID             = false;
369     mUsesLocalInvocationID       = false;
370     mUsesGlobalInvocationID      = false;
371     mUsesLocalInvocationIndex    = false;
372     mUsesXor                     = false;
373     mUsesDiscardRewriting        = false;
374     mUsesNestedBreak             = false;
375     mRequiresIEEEStrictCompiling = false;
376     mUseZeroArray                = false;
377     mUsesSecondaryColor          = false;
378 
379     mDepthLayout = EdUnspecified;
380 
381     mUniqueIndex = 0;
382 
383     mOutputLod0Function      = false;
384     mInsideDiscontinuousLoop = false;
385     mNestedLoopDepth         = 0;
386 
387     mExcessiveLoopIndex = nullptr;
388 
389     mStructureHLSL       = new StructureHLSL;
390     mTextureFunctionHLSL = new TextureFunctionHLSL;
391     mImageFunctionHLSL   = new ImageFunctionHLSL;
392     mAtomicCounterFunctionHLSL =
393         new AtomicCounterFunctionHLSL(compileOptions.forceAtomicValueResolution);
394 
395     unsigned int firstUniformRegister = compileOptions.skipD3DConstantRegisterZero ? 1u : 0u;
396     mResourcesHLSL = new ResourcesHLSL(mStructureHLSL, outputType, uniforms, firstUniformRegister);
397 
398     if (mOutputType == SH_HLSL_3_0_OUTPUT)
399     {
400         // Fragment shaders need dx_DepthRange, dx_ViewCoords, dx_DepthFront,
401         // and dx_FragCoordOffset.
402         // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and
403         // dx_ViewAdjust.
404         if (mShaderType == GL_VERTEX_SHADER)
405         {
406             mResourcesHLSL->reserveUniformRegisters(3);
407         }
408         else
409         {
410             mResourcesHLSL->reserveUniformRegisters(4);
411         }
412     }
413 
414     // Reserve registers for the default uniform block and driver constants
415     mResourcesHLSL->reserveUniformBlockRegisters(2);
416 
417     mSSBOOutputHLSL = new ShaderStorageBlockOutputHLSL(this, mResourcesHLSL, shaderStorageBlocks);
418 }
419 
~OutputHLSL()420 OutputHLSL::~OutputHLSL()
421 {
422     SafeDelete(mSSBOOutputHLSL);
423     SafeDelete(mStructureHLSL);
424     SafeDelete(mResourcesHLSL);
425     SafeDelete(mTextureFunctionHLSL);
426     SafeDelete(mImageFunctionHLSL);
427     SafeDelete(mAtomicCounterFunctionHLSL);
428     for (auto &eqFunction : mStructEqualityFunctions)
429     {
430         SafeDelete(eqFunction);
431     }
432     for (auto &eqFunction : mArrayEqualityFunctions)
433     {
434         SafeDelete(eqFunction);
435     }
436 }
437 
output(TIntermNode * treeRoot,TInfoSinkBase & objSink)438 void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
439 {
440     BuiltInFunctionEmulator builtInFunctionEmulator;
441     InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
442     if (mCompileOptions.emulateIsnanFloatFunction)
443     {
444         InitBuiltInIsnanFunctionEmulatorForHLSLWorkarounds(&builtInFunctionEmulator,
445                                                            mShaderVersion);
446     }
447 
448     builtInFunctionEmulator.markBuiltInFunctionsForEmulation(treeRoot);
449 
450     // Now that we are done changing the AST, do the analyses need for HLSL generation
451     CallDAG::InitResult success = mCallDag.init(treeRoot, nullptr);
452     ASSERT(success == CallDAG::INITDAG_SUCCESS);
453     mASTMetadataList = CreateASTMetadataHLSL(treeRoot, mCallDag);
454 
455     const std::vector<MappedStruct> std140Structs = FlagStd140Structs(treeRoot);
456     // TODO(oetuaho): The std140Structs could be filtered based on which ones actually get used in
457     // the shader code. When we add shader storage blocks we might also consider an alternative
458     // solution, since the struct mapping won't work very well for shader storage blocks.
459 
460     // Output the body and footer first to determine what has to go in the header
461     mInfoSinkStack.push(&mBody);
462     treeRoot->traverse(this);
463     mInfoSinkStack.pop();
464 
465     mInfoSinkStack.push(&mFooter);
466     mInfoSinkStack.pop();
467 
468     mInfoSinkStack.push(&mHeader);
469     header(mHeader, std140Structs, &builtInFunctionEmulator);
470     mInfoSinkStack.pop();
471 
472     objSink << mHeader.c_str();
473     objSink << mBody.c_str();
474     objSink << mFooter.c_str();
475 
476     builtInFunctionEmulator.cleanup();
477 }
478 
getShaderStorageBlockRegisterMap() const479 const std::map<std::string, unsigned int> &OutputHLSL::getShaderStorageBlockRegisterMap() const
480 {
481     return mResourcesHLSL->getShaderStorageBlockRegisterMap();
482 }
483 
getUniformBlockRegisterMap() const484 const std::map<std::string, unsigned int> &OutputHLSL::getUniformBlockRegisterMap() const
485 {
486     return mResourcesHLSL->getUniformBlockRegisterMap();
487 }
488 
getUniformBlockUseStructuredBufferMap() const489 const std::map<std::string, bool> &OutputHLSL::getUniformBlockUseStructuredBufferMap() const
490 {
491     return mResourcesHLSL->getUniformBlockUseStructuredBufferMap();
492 }
493 
getUniformRegisterMap() const494 const std::map<std::string, unsigned int> &OutputHLSL::getUniformRegisterMap() const
495 {
496     return mResourcesHLSL->getUniformRegisterMap();
497 }
498 
getReadonlyImage2DRegisterIndex() const499 unsigned int OutputHLSL::getReadonlyImage2DRegisterIndex() const
500 {
501     return mResourcesHLSL->getReadonlyImage2DRegisterIndex();
502 }
503 
getImage2DRegisterIndex() const504 unsigned int OutputHLSL::getImage2DRegisterIndex() const
505 {
506     return mResourcesHLSL->getImage2DRegisterIndex();
507 }
508 
getUsedImage2DFunctionNames() const509 const std::set<std::string> &OutputHLSL::getUsedImage2DFunctionNames() const
510 {
511     return mImageFunctionHLSL->getUsedImage2DFunctionNames();
512 }
513 
// Builds the text of an initializer for a value of |type| whose source expression is |name|.
// Arrays and structs become a brace-enclosed list with one entry per element or field, produced
// recursively one indentation level deeper; anything else is just |name|. Each indentation level
// is four spaces.
TString OutputHLSL::structInitializerString(int indent,
                                            const TType &type,
                                            const TString &name) const
{
    TString padding;
    for (int level = 0; level < indent; ++level)
    {
        padding += "    ";
    }

    const bool isArray  = type.isArray();
    const bool isStruct = !isArray && type.getBasicType() == EbtStruct;
    if (!isArray && !isStruct)
    {
        return padding + name;
    }

    TString result = padding + "{\n";

    if (isArray)
    {
        const unsigned int elementCount = type.getOutermostArraySize();
        for (unsigned int element = 0u; element < elementCount; ++element)
        {
            TStringStream elementName = sh::InitializeStream<TStringStream>();
            elementName << name << "[" << element << "]";

            TType elementType = type;
            elementType.toArrayElementType();

            result += structInitializerString(indent + 1, elementType, elementName.str());
            // Every entry except the last is followed by a comma.
            result += (element + 1 < elementCount) ? ",\n" : "\n";
        }
    }
    else
    {
        const TFieldList &fields = type.getStruct()->fields();
        const size_t fieldCount  = fields.size();
        for (size_t fieldIndex = 0; fieldIndex < fieldCount; ++fieldIndex)
        {
            const TField &field     = *fields[fieldIndex];
            const TString fieldPath = name + "." + Decorate(field.name());

            result += structInitializerString(indent + 1, *field.type(), fieldPath);
            // Every entry except the last is followed by a comma.
            result += (fieldIndex + 1 < fieldCount) ? ",\n" : "\n";
        }
    }

    result += padding + "}";
    return result;
}
571 
generateStructMapping(const std::vector<MappedStruct> & std140Structs) const572 TString OutputHLSL::generateStructMapping(const std::vector<MappedStruct> &std140Structs) const
573 {
574     TString mappedStructs;
575 
576     for (auto &mappedStruct : std140Structs)
577     {
578         const TInterfaceBlock *interfaceBlock =
579             mappedStruct.blockDeclarator->getType().getInterfaceBlock();
580         TQualifier qualifier = mappedStruct.blockDeclarator->getType().getQualifier();
581         switch (qualifier)
582         {
583             case EvqUniform:
584                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
585                 {
586                     continue;
587                 }
588                 break;
589             case EvqBuffer:
590                 continue;
591             default:
592                 UNREACHABLE();
593                 return mappedStructs;
594         }
595 
596         unsigned int instanceCount = 1u;
597         bool isInstanceArray       = mappedStruct.blockDeclarator->isArray();
598         if (isInstanceArray)
599         {
600             instanceCount = mappedStruct.blockDeclarator->getOutermostArraySize();
601         }
602 
603         for (unsigned int instanceArrayIndex = 0; instanceArrayIndex < instanceCount;
604              ++instanceArrayIndex)
605         {
606             TString originalName;
607             TString mappedName("map");
608 
609             if (mappedStruct.blockDeclarator->variable().symbolType() != SymbolType::Empty)
610             {
611                 const ImmutableString &instanceName =
612                     mappedStruct.blockDeclarator->variable().name();
613                 unsigned int instanceStringArrayIndex = GL_INVALID_INDEX;
614                 if (isInstanceArray)
615                     instanceStringArrayIndex = instanceArrayIndex;
616                 TString instanceString = mResourcesHLSL->InterfaceBlockInstanceString(
617                     instanceName, instanceStringArrayIndex);
618                 originalName += instanceString;
619                 mappedName += instanceString;
620                 originalName += ".";
621                 mappedName += "_";
622             }
623 
624             TString fieldName = Decorate(mappedStruct.field->name());
625             originalName += fieldName;
626             mappedName += fieldName;
627 
628             TType *structType = mappedStruct.field->type();
629             mappedStructs +=
630                 "static " + Decorate(structType->getStruct()->name()) + " " + mappedName;
631 
632             if (structType->isArray())
633             {
634                 mappedStructs += ArrayString(*mappedStruct.field->type()).data();
635             }
636 
637             mappedStructs += " =\n";
638             mappedStructs += structInitializerString(0, *structType, originalName);
639             mappedStructs += ";\n";
640         }
641     }
642     return mappedStructs;
643 }
644 
writeReferencedAttributes(TInfoSinkBase & out) const645 void OutputHLSL::writeReferencedAttributes(TInfoSinkBase &out) const
646 {
647     for (const auto &attribute : mReferencedAttributes)
648     {
649         const TType &type           = attribute.second->getType();
650         const ImmutableString &name = attribute.second->name();
651 
652         out << "static " << TypeString(type) << " " << Decorate(name) << ArrayString(type) << " = "
653             << zeroInitializer(type) << ";\n";
654     }
655 }
656 
writeReferencedVaryings(TInfoSinkBase & out) const657 void OutputHLSL::writeReferencedVaryings(TInfoSinkBase &out) const
658 {
659     for (const auto &varying : mReferencedVaryings)
660     {
661         const TType &type = varying.second->getType();
662 
663         // Program linking depends on this exact format
664         out << "static " << InterpolationString(type.getQualifier()) << " " << TypeString(type)
665             << " " << DecorateVariableIfNeeded(*varying.second) << ArrayString(type) << " = "
666             << zeroInitializer(type) << ";\n";
667     }
668 }
669 
header(TInfoSinkBase & out,const std::vector<MappedStruct> & std140Structs,const BuiltInFunctionEmulator * builtInFunctionEmulator) const670 void OutputHLSL::header(TInfoSinkBase &out,
671                         const std::vector<MappedStruct> &std140Structs,
672                         const BuiltInFunctionEmulator *builtInFunctionEmulator) const
673 {
674     TString mappedStructs;
675     if (mNeedStructMapping)
676     {
677         mappedStructs = generateStructMapping(std140Structs);
678     }
679 
680     // Suppress some common warnings:
681     // 3556 : Integer divides might be much slower, try using uints if possible.
682     // 3571 : The pow(f, e) intrinsic function won't work for negative f, use abs(f) or
683     //        conditionally handle negative values if you expect them.
684     out << "#pragma warning( disable: 3556 3571 )\n";
685 
686     out << mStructureHLSL->structsHeader();
687 
688     mResourcesHLSL->uniformsHeader(out, mOutputType, mReferencedUniforms, mSymbolTable);
689     out << mResourcesHLSL->uniformBlocksHeader(mReferencedUniformBlocks, mUniformBlockOptimizedMap);
690     mSSBOOutputHLSL->writeShaderStorageBlocksHeader(mShaderType, out);
691 
692     if (!mEqualityFunctions.empty())
693     {
694         out << "\n// Equality functions\n\n";
695         for (const auto &eqFunction : mEqualityFunctions)
696         {
697             out << eqFunction->functionDefinition << "\n";
698         }
699     }
700     if (!mArrayAssignmentFunctions.empty())
701     {
702         out << "\n// Assignment functions\n\n";
703         for (const auto &assignmentFunction : mArrayAssignmentFunctions)
704         {
705             out << assignmentFunction.functionDefinition << "\n";
706         }
707     }
708     if (!mArrayConstructIntoFunctions.empty())
709     {
710         out << "\n// Array constructor functions\n\n";
711         for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
712         {
713             out << constructIntoFunction.functionDefinition << "\n";
714         }
715     }
716     if (!mFlatEvaluateFunctions.empty())
717     {
718         out << "\n// Evaluate* functions for flat inputs\n\n";
719         for (const auto &flatEvaluateFunction : mFlatEvaluateFunctions)
720         {
721             out << flatEvaluateFunction.functionDefinition << "\n";
722         }
723     }
724 
725     if (mUsesDiscardRewriting)
726     {
727         out << "#define ANGLE_USES_DISCARD_REWRITING\n";
728     }
729 
730     if (mUsesNestedBreak)
731     {
732         out << "#define ANGLE_USES_NESTED_BREAK\n";
733     }
734 
735     if (mRequiresIEEEStrictCompiling)
736     {
737         out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
738     }
739 
740     out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
741            "#define LOOP [loop]\n"
742            "#define FLATTEN [flatten]\n"
743            "#else\n"
744            "#define LOOP\n"
745            "#define FLATTEN\n"
746            "#endif\n";
747 
748     // array stride for atomic counter buffers is always 4 per original extension
749     // ARB_shader_atomic_counters and discussion on
750     // https://github.com/KhronosGroup/OpenGL-API/issues/5
751     out << "\n#define ATOMIC_COUNTER_ARRAY_STRIDE 4\n\n";
752 
753     if (mUseZeroArray)
754     {
755         out << DefineZeroArray() << "\n";
756     }
757 
758     if (mShaderType == GL_FRAGMENT_SHADER)
759     {
760         const bool usingMRTExtension =
761             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_draw_buffers);
762         const bool usingBFEExtension =
763             IsExtensionEnabled(mExtensionBehavior, TExtension::EXT_blend_func_extended);
764 
765         out << "// Varyings\n";
766         writeReferencedVaryings(out);
767         out << "\n";
768 
769         if (mShaderVersion >= 300)
770         {
771             for (const auto &outputVariable : mReferencedOutputVariables)
772             {
773                 const ImmutableString &variableName = outputVariable.second->name();
774                 const TType &variableType           = outputVariable.second->getType();
775 
776                 out << "static " << TypeString(variableType) << " out_" << variableName
777                     << ArrayString(variableType) << " = " << zeroInitializer(variableType) << ";\n";
778             }
779         }
780         else
781         {
782             const unsigned int numColorValues = usingMRTExtension ? mNumRenderTargets : 1;
783 
784             out << "static float4 gl_Color[" << numColorValues
785                 << "] =\n"
786                    "{\n";
787             for (unsigned int i = 0; i < numColorValues; i++)
788             {
789                 out << "    float4(0, 0, 0, 0)";
790                 if (i + 1 != numColorValues)
791                 {
792                     out << ",";
793                 }
794                 out << "\n";
795             }
796 
797             out << "};\n";
798 
799             if (usingBFEExtension && mUsesSecondaryColor)
800             {
801                 out << "static float4 gl_SecondaryColor[" << mMaxDualSourceDrawBuffers
802                     << "] = \n"
803                        "{\n";
804                 for (int i = 0; i < mMaxDualSourceDrawBuffers; i++)
805                 {
806                     out << "    float4(0, 0, 0, 0)";
807                     if (i + 1 != mMaxDualSourceDrawBuffers)
808                     {
809                         out << ",";
810                     }
811                     out << "\n";
812                 }
813                 out << "};\n";
814             }
815         }
816 
817         if (mUsesViewID)
818         {
819             out << "static uint ViewID_OVR = 0;\n";
820         }
821 
822         if (mUsesFragDepth)
823         {
824             out << "static float gl_Depth = 0.0;\n";
825         }
826 
827         if (mUsesSampleID)
828         {
829             out << "static int gl_SampleID = 0;\n";
830         }
831 
832         if (mUsesSamplePosition)
833         {
834             out << "static float2 gl_SamplePosition = float2(0.0, 0.0);\n";
835         }
836 
837         if (mUsesSampleMaskIn)
838         {
839             out << "static int gl_SampleMaskIn[1] = {0};\n";
840         }
841 
842         if (mUsesSampleMask)
843         {
844             out << "static int gl_SampleMask[1] = {0};\n";
845         }
846 
847         if (mUsesNumSamples)
848         {
849             out << "static int gl_NumSamples = GetRenderTargetSampleCount();\n";
850         }
851 
852         if (mUsesFragCoord)
853         {
854             out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
855         }
856 
857         if (mUsesPointCoord)
858         {
859             out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
860         }
861 
862         if (mUsesFrontFacing)
863         {
864             out << "static bool gl_FrontFacing = false;\n";
865         }
866 
867         if (mUsesHelperInvocation)
868         {
869             out << "static bool gl_HelperInvocation = false;\n";
870         }
871 
872         out << "\n";
873 
874         if (mUsesDepthRange)
875         {
876             out << "struct gl_DepthRangeParameters\n"
877                    "{\n"
878                    "    float near;\n"
879                    "    float far;\n"
880                    "    float diff;\n"
881                    "};\n"
882                    "\n";
883         }
884 
885         if (mOutputType == SH_HLSL_4_1_OUTPUT)
886         {
887             out << "cbuffer DriverConstants : register(b1)\n"
888                    "{\n";
889 
890             if (mUsesDepthRange)
891             {
892                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
893             }
894 
895             if (mUsesFragCoord)
896             {
897                 out << "    float4 dx_ViewCoords : packoffset(c1);\n";
898                 out << "    float2 dx_FragCoordOffset : packoffset(c3);\n";
899             }
900 
901             if (mUsesFragCoord || mUsesFrontFacing)
902             {
903                 out << "    float3 dx_DepthFront : packoffset(c2);\n";
904             }
905 
906             if (mUsesFragCoord)
907             {
908                 // dx_ViewScale is only used in the fragment shader to correct
909                 // the value for glFragCoord if necessary
910                 out << "    float2 dx_ViewScale : packoffset(c3.z);\n";
911             }
912 
913             if (mOutputType == SH_HLSL_4_1_OUTPUT)
914             {
915                 out << "    uint dx_Misc : packoffset(c2.w);\n";
916                 unsigned int registerIndex = 4;
917                 mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
918                 // Sampler metadata struct must be two 4-vec, 32 bytes.
919                 registerIndex += mResourcesHLSL->getSamplerCount() * 2;
920                 mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
921             }
922 
923             out << "};\n";
924 
925             if (mOutputType == SH_HLSL_4_1_OUTPUT && mResourcesHLSL->hasImages())
926             {
927                 out << kImage2DFunctionString << "\n";
928             }
929         }
930         else
931         {
932             if (mUsesDepthRange)
933             {
934                 out << "uniform float3 dx_DepthRange : register(c0);";
935             }
936 
937             if (mUsesFragCoord)
938             {
939                 out << "uniform float4 dx_ViewCoords : register(c1);\n";
940             }
941 
942             if (mUsesFragCoord || mUsesFrontFacing)
943             {
944                 out << "uniform float3 dx_DepthFront : register(c2);\n";
945                 out << "uniform float2 dx_FragCoordOffset : register(c3);\n";
946             }
947         }
948 
949         out << "\n";
950 
951         if (mUsesDepthRange)
952         {
953             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
954                    "dx_DepthRange.y, dx_DepthRange.z};\n"
955                    "\n";
956         }
957 
958         if (mClipDistanceSize)
959         {
960             out << "static float gl_ClipDistance[" << static_cast<int>(mClipDistanceSize)
961                 << "] = {0";
962             for (unsigned int i = 1; i < mClipDistanceSize; i++)
963             {
964                 out << ", 0";
965             }
966             out << "};\n";
967         }
968 
969         if (mCullDistanceSize)
970         {
971             out << "static float gl_CullDistance[" << static_cast<int>(mCullDistanceSize)
972                 << "] = {0";
973             for (unsigned int i = 1; i < mCullDistanceSize; i++)
974             {
975                 out << ", 0";
976             }
977             out << "};\n";
978         }
979 
980         if (usingMRTExtension && mNumRenderTargets > 1)
981         {
982             out << "#define GL_USES_MRT\n";
983         }
984 
985         if (mUsesFragColor)
986         {
987             out << "#define GL_USES_FRAG_COLOR\n";
988         }
989 
990         if (mUsesFragData)
991         {
992             out << "#define GL_USES_FRAG_DATA\n";
993         }
994 
995         if (mShaderVersion < 300 && usingBFEExtension && mUsesSecondaryColor)
996         {
997             out << "#define GL_USES_SECONDARY_COLOR\n";
998         }
999     }
1000     else if (mShaderType == GL_VERTEX_SHADER)
1001     {
1002         out << "// Attributes\n";
1003         writeReferencedAttributes(out);
1004         out << "\n"
1005                "static float4 gl_Position = float4(0, 0, 0, 0);\n";
1006 
1007         if (mClipDistanceSize)
1008         {
1009             out << "static float gl_ClipDistance[" << static_cast<int>(mClipDistanceSize)
1010                 << "] = {0";
1011             for (size_t i = 1; i < mClipDistanceSize; i++)
1012             {
1013                 out << ", 0";
1014             }
1015             out << "};\n";
1016         }
1017 
1018         if (mCullDistanceSize)
1019         {
1020             out << "static float gl_CullDistance[" << static_cast<int>(mCullDistanceSize)
1021                 << "] = {0";
1022             for (size_t i = 1; i < mCullDistanceSize; i++)
1023             {
1024                 out << ", 0";
1025             }
1026             out << "};\n";
1027         }
1028 
1029         if (mUsesPointSize)
1030         {
1031             out << "static float gl_PointSize = float(1);\n";
1032         }
1033 
1034         if (mUsesInstanceID)
1035         {
1036             out << "static int gl_InstanceID;\n";
1037         }
1038 
1039         if (mUsesViewID)
1040         {
1041             out << "static uint ViewID_OVR;\n";
1042         }
1043 
1044         if (mUsesVertexID)
1045         {
1046             out << "static int gl_VertexID;\n";
1047         }
1048 
1049         out << "\n"
1050                "// Varyings\n";
1051         writeReferencedVaryings(out);
1052         out << "\n";
1053 
1054         if (mUsesDepthRange)
1055         {
1056             out << "struct gl_DepthRangeParameters\n"
1057                    "{\n"
1058                    "    float near;\n"
1059                    "    float far;\n"
1060                    "    float diff;\n"
1061                    "};\n"
1062                    "\n";
1063         }
1064 
1065         if (mOutputType == SH_HLSL_4_1_OUTPUT)
1066         {
1067             out << "cbuffer DriverConstants : register(b1)\n"
1068                    "{\n";
1069 
1070             if (mUsesDepthRange)
1071             {
1072                 out << "    float3 dx_DepthRange : packoffset(c0);\n";
1073             }
1074 
1075             // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9
1076             // shaders. However, we declare it for all shaders (including Feature Level 10+).
1077             // The bytecode is the same whether we declare it or not, since D3DCompiler removes it
1078             // if it's unused.
1079             out << "    float4 dx_ViewAdjust : packoffset(c1);\n";
1080             out << "    float2 dx_ViewCoords : packoffset(c2);\n";
1081             out << "    float2 dx_ViewScale  : packoffset(c3);\n";
1082 
1083             out << "    float clipControlOrigin : packoffset(c3.z);\n";
1084             out << "    float clipControlZeroToOne : packoffset(c3.w);\n";
1085 
1086             if (mOutputType == SH_HLSL_4_1_OUTPUT)
1087             {
1088                 mResourcesHLSL->samplerMetadataUniforms(out, 5);
1089             }
1090 
1091             if (mUsesVertexID)
1092             {
1093                 out << "    uint dx_VertexID : packoffset(c4.x);\n";
1094             }
1095 
1096             if (mClipDistanceSize)
1097             {
1098                 out << "    uint clipDistancesEnabled : packoffset(c4.y);\n";
1099             }
1100 
1101             out << "};\n"
1102                    "\n";
1103         }
1104         else
1105         {
1106             if (mUsesDepthRange)
1107             {
1108                 out << "uniform float3 dx_DepthRange : register(c0);\n";
1109             }
1110 
1111             out << "uniform float4 dx_ViewAdjust : register(c1);\n";
1112             out << "uniform float2 dx_ViewCoords : register(c2);\n";
1113 
1114             out << "static const float clipControlOrigin = -1.0f;\n";
1115             out << "static const float clipControlZeroToOne = 0.0f;\n";
1116 
1117             out << "\n";
1118         }
1119 
1120         if (mUsesDepthRange)
1121         {
1122             out << "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, "
1123                    "dx_DepthRange.y, dx_DepthRange.z};\n"
1124                    "\n";
1125         }
1126 
1127         if (mOutputType == SH_HLSL_4_1_OUTPUT && mResourcesHLSL->hasImages())
1128         {
1129             out << kImage2DFunctionString << "\n";
1130         }
1131     }
1132     else  // Compute shader
1133     {
1134         ASSERT(mShaderType == GL_COMPUTE_SHADER);
1135 
1136         out << "cbuffer DriverConstants : register(b1)\n"
1137                "{\n";
1138         if (mUsesNumWorkGroups)
1139         {
1140             out << "    uint3 gl_NumWorkGroups : packoffset(c0);\n";
1141         }
1142         ASSERT(mOutputType == SH_HLSL_4_1_OUTPUT);
1143         unsigned int registerIndex = 1;
1144         mResourcesHLSL->samplerMetadataUniforms(out, registerIndex);
1145         // Sampler metadata struct must be two 4-vec, 32 bytes.
1146         registerIndex += mResourcesHLSL->getSamplerCount() * 2;
1147         mResourcesHLSL->imageMetadataUniforms(out, registerIndex);
1148         out << "};\n";
1149 
1150         out << kImage2DFunctionString << "\n";
1151 
1152         std::ostringstream systemValueDeclaration  = sh::InitializeStream<std::ostringstream>();
1153         std::ostringstream glBuiltinInitialization = sh::InitializeStream<std::ostringstream>();
1154 
1155         systemValueDeclaration << "\nstruct CS_INPUT\n{\n";
1156         glBuiltinInitialization << "\nvoid initGLBuiltins(CS_INPUT input)\n" << "{\n";
1157 
1158         if (mUsesWorkGroupID)
1159         {
1160             out << "static uint3 gl_WorkGroupID = uint3(0, 0, 0);\n";
1161             systemValueDeclaration << "    uint3 dx_WorkGroupID : " << "SV_GroupID;\n";
1162             glBuiltinInitialization << "    gl_WorkGroupID = input.dx_WorkGroupID;\n";
1163         }
1164 
1165         if (mUsesLocalInvocationID)
1166         {
1167             out << "static uint3 gl_LocalInvocationID = uint3(0, 0, 0);\n";
1168             systemValueDeclaration << "    uint3 dx_LocalInvocationID : " << "SV_GroupThreadID;\n";
1169             glBuiltinInitialization << "    gl_LocalInvocationID = input.dx_LocalInvocationID;\n";
1170         }
1171 
1172         if (mUsesGlobalInvocationID)
1173         {
1174             out << "static uint3 gl_GlobalInvocationID = uint3(0, 0, 0);\n";
1175             systemValueDeclaration << "    uint3 dx_GlobalInvocationID : "
1176                                    << "SV_DispatchThreadID;\n";
1177             glBuiltinInitialization << "    gl_GlobalInvocationID = input.dx_GlobalInvocationID;\n";
1178         }
1179 
1180         if (mUsesLocalInvocationIndex)
1181         {
1182             out << "static uint gl_LocalInvocationIndex = uint(0);\n";
1183             systemValueDeclaration << "    uint dx_LocalInvocationIndex : " << "SV_GroupIndex;\n";
1184             glBuiltinInitialization
1185                 << "    gl_LocalInvocationIndex = input.dx_LocalInvocationIndex;\n";
1186         }
1187 
1188         systemValueDeclaration << "};\n\n";
1189         glBuiltinInitialization << "};\n\n";
1190 
1191         out << systemValueDeclaration.str();
1192         out << glBuiltinInitialization.str();
1193     }
1194 
1195     if (!mappedStructs.empty())
1196     {
1197         out << "// Structures from std140 blocks with padding removed\n";
1198         out << "\n";
1199         out << mappedStructs;
1200         out << "\n";
1201     }
1202 
1203     bool getDimensionsIgnoresBaseLevel = mCompileOptions.HLSLGetDimensionsIgnoresBaseLevel;
1204     mTextureFunctionHLSL->textureFunctionHeader(out, mOutputType, getDimensionsIgnoresBaseLevel);
1205     mImageFunctionHLSL->imageFunctionHeader(out);
1206     mAtomicCounterFunctionHLSL->atomicCounterFunctionHeader(out);
1207 
1208     if (mUsesFragCoord)
1209     {
1210         out << "#define GL_USES_FRAG_COORD\n";
1211     }
1212 
1213     if (mUsesPointCoord)
1214     {
1215         out << "#define GL_USES_POINT_COORD\n";
1216     }
1217 
1218     if (mUsesFrontFacing)
1219     {
1220         out << "#define GL_USES_FRONT_FACING\n";
1221     }
1222 
1223     if (mUsesHelperInvocation)
1224     {
1225         out << "#define GL_USES_HELPER_INVOCATION\n";
1226     }
1227 
1228     if (mUsesPointSize)
1229     {
1230         out << "#define GL_USES_POINT_SIZE\n";
1231     }
1232 
1233     if (mHasMultiviewExtensionEnabled)
1234     {
1235         out << "#define GL_MULTIVIEW_ENABLED\n";
1236     }
1237 
1238     if (mUsesVertexID)
1239     {
1240         out << "#define GL_USES_VERTEX_ID\n";
1241     }
1242 
1243     if (mUsesViewID)
1244     {
1245         out << "#define GL_USES_VIEW_ID\n";
1246     }
1247 
1248     if (mUsesSampleID)
1249     {
1250         out << "#define GL_USES_SAMPLE_ID\n";
1251     }
1252 
1253     if (mUsesSamplePosition)
1254     {
1255         out << "#define GL_USES_SAMPLE_POSITION\n";
1256     }
1257 
1258     if (mUsesSampleMaskIn)
1259     {
1260         out << "#define GL_USES_SAMPLE_MASK_IN\n";
1261     }
1262 
1263     if (mUsesSampleMask)
1264     {
1265         out << "#define GL_USES_SAMPLE_MASK_OUT\n";
1266     }
1267 
1268     if (mUsesFragDepth)
1269     {
1270         switch (mDepthLayout)
1271         {
1272             case EdGreater:
1273                 out << "#define GL_USES_FRAG_DEPTH_GREATER\n";
1274                 break;
1275             case EdLess:
1276                 out << "#define GL_USES_FRAG_DEPTH_LESS\n";
1277                 break;
1278             default:
1279                 out << "#define GL_USES_FRAG_DEPTH\n";
1280                 break;
1281         }
1282     }
1283 
1284     if (mUsesDepthRange)
1285     {
1286         out << "#define GL_USES_DEPTH_RANGE\n";
1287     }
1288 
1289     if (mUsesXor)
1290     {
1291         out << "bool xor(bool p, bool q)\n"
1292                "{\n"
1293                "    return (p || q) && !(p && q);\n"
1294                "}\n"
1295                "\n";
1296     }
1297 
1298     builtInFunctionEmulator->outputEmulatedFunctions(out);
1299 }
1300 
visitSymbol(TIntermSymbol * node)1301 void OutputHLSL::visitSymbol(TIntermSymbol *node)
1302 {
1303     const TVariable &variable = node->variable();
1304 
1305     // Empty symbols can only appear in declarations and function arguments, and in either of those
1306     // cases the symbol nodes are not visited.
1307     ASSERT(variable.symbolType() != SymbolType::Empty);
1308 
1309     TInfoSinkBase &out = getInfoSink();
1310 
1311     // Handle accessing std140 structs by value
1312     if (IsInStd140UniformBlock(node) && node->getBasicType() == EbtStruct &&
1313         needStructMapping(node))
1314     {
1315         mNeedStructMapping = true;
1316         out << "map";
1317     }
1318 
1319     const ImmutableString &name     = variable.name();
1320     const TSymbolUniqueId &uniqueId = variable.uniqueId();
1321 
1322     if (name == "gl_DepthRange")
1323     {
1324         mUsesDepthRange = true;
1325         out << name;
1326     }
1327     else if (name == "gl_NumSamples")
1328     {
1329         mUsesNumSamples = true;
1330         out << name;
1331     }
1332     else if (IsAtomicCounter(variable.getType().getBasicType()))
1333     {
1334         const TType &variableType = variable.getType();
1335         if (variableType.getQualifier() == EvqUniform)
1336         {
1337             TLayoutQualifier layout             = variableType.getLayoutQualifier();
1338             mReferencedUniforms[uniqueId.get()] = &variable;
1339             out << getAtomicCounterNameForBinding(layout.binding) << ", " << layout.offset;
1340         }
1341         else
1342         {
1343             TString varName = DecorateVariableIfNeeded(variable);
1344             out << varName << ", " << varName << "_offset";
1345         }
1346     }
1347     else
1348     {
1349         const TType &variableType = variable.getType();
1350         TQualifier qualifier      = variable.getType().getQualifier();
1351 
1352         ensureStructDefined(variableType);
1353 
1354         if (qualifier == EvqUniform)
1355         {
1356             const TInterfaceBlock *interfaceBlock = variableType.getInterfaceBlock();
1357 
1358             if (interfaceBlock)
1359             {
1360                 if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1361                 {
1362                     const TVariable *instanceVariable = nullptr;
1363                     if (variableType.isInterfaceBlock())
1364                     {
1365                         instanceVariable = &variable;
1366                     }
1367                     mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1368                         new TReferencedBlock(interfaceBlock, instanceVariable);
1369                 }
1370             }
1371             else
1372             {
1373                 mReferencedUniforms[uniqueId.get()] = &variable;
1374             }
1375 
1376             out << DecorateVariableIfNeeded(variable);
1377         }
1378         else if (qualifier == EvqBuffer)
1379         {
1380             UNREACHABLE();
1381         }
1382         else if (qualifier == EvqAttribute || qualifier == EvqVertexIn)
1383         {
1384             mReferencedAttributes[uniqueId.get()] = &variable;
1385             out << Decorate(name);
1386         }
1387         else if (IsVarying(qualifier))
1388         {
1389             mReferencedVaryings[uniqueId.get()] = &variable;
1390             out << DecorateVariableIfNeeded(variable);
1391         }
1392         else if (qualifier == EvqFragmentOut)
1393         {
1394             mReferencedOutputVariables[uniqueId.get()] = &variable;
1395             out << "out_" << name;
1396         }
1397         else if (qualifier == EvqViewIDOVR)
1398         {
1399             out << name;
1400             mUsesViewID = true;
1401         }
1402         else if (qualifier == EvqClipDistance)
1403         {
1404             out << name;
1405         }
1406         else if (qualifier == EvqCullDistance)
1407         {
1408             out << name;
1409         }
1410         else if (qualifier == EvqFragColor)
1411         {
1412             out << "gl_Color[0]";
1413             mUsesFragColor = true;
1414         }
1415         else if (qualifier == EvqFragData)
1416         {
1417             out << "gl_Color";
1418             mUsesFragData = true;
1419         }
1420         else if (qualifier == EvqSecondaryFragColorEXT)
1421         {
1422             out << "gl_SecondaryColor[0]";
1423             mUsesSecondaryColor = true;
1424         }
1425         else if (qualifier == EvqSecondaryFragDataEXT)
1426         {
1427             out << "gl_SecondaryColor";
1428             mUsesSecondaryColor = true;
1429         }
1430         else if (qualifier == EvqFragCoord)
1431         {
1432             mUsesFragCoord = true;
1433             out << name;
1434         }
1435         else if (qualifier == EvqPointCoord)
1436         {
1437             mUsesPointCoord = true;
1438             out << name;
1439         }
1440         else if (qualifier == EvqFrontFacing)
1441         {
1442             mUsesFrontFacing = true;
1443             out << name;
1444         }
1445         else if (qualifier == EvqHelperInvocation)
1446         {
1447             mUsesHelperInvocation = true;
1448             out << name;
1449         }
1450         else if (qualifier == EvqPointSize)
1451         {
1452             mUsesPointSize = true;
1453             out << name;
1454         }
1455         else if (qualifier == EvqInstanceID)
1456         {
1457             mUsesInstanceID = true;
1458             out << name;
1459         }
1460         else if (qualifier == EvqVertexID)
1461         {
1462             mUsesVertexID = true;
1463             out << name;
1464         }
1465         else if (name == "gl_FragDepthEXT" || name == "gl_FragDepth")
1466         {
1467             mUsesFragDepth = true;
1468             mDepthLayout   = variableType.getLayoutQualifier().depth;
1469             out << "gl_Depth";
1470         }
1471         else if (qualifier == EvqSampleID)
1472         {
1473             mUsesSampleID = true;
1474             out << name;
1475         }
1476         else if (qualifier == EvqSamplePosition)
1477         {
1478             mUsesSamplePosition = true;
1479             out << name;
1480         }
1481         else if (qualifier == EvqSampleMaskIn)
1482         {
1483             mUsesSampleMaskIn = true;
1484             out << name;
1485         }
1486         else if (qualifier == EvqSampleMask)
1487         {
1488             mUsesSampleMask = true;
1489             out << name;
1490         }
1491         else if (qualifier == EvqNumWorkGroups)
1492         {
1493             mUsesNumWorkGroups = true;
1494             out << name;
1495         }
1496         else if (qualifier == EvqWorkGroupID)
1497         {
1498             mUsesWorkGroupID = true;
1499             out << name;
1500         }
1501         else if (qualifier == EvqLocalInvocationID)
1502         {
1503             mUsesLocalInvocationID = true;
1504             out << name;
1505         }
1506         else if (qualifier == EvqGlobalInvocationID)
1507         {
1508             mUsesGlobalInvocationID = true;
1509             out << name;
1510         }
1511         else if (qualifier == EvqLocalInvocationIndex)
1512         {
1513             mUsesLocalInvocationIndex = true;
1514             out << name;
1515         }
1516         else
1517         {
1518             out << DecorateVariableIfNeeded(variable);
1519         }
1520     }
1521 }
1522 
outputEqual(Visit visit,const TType & type,TOperator op,TInfoSinkBase & out)1523 void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
1524 {
1525     if (type.isScalar() && !type.isArray())
1526     {
1527         if (op == EOpEqual)
1528         {
1529             outputTriplet(out, visit, "(", " == ", ")");
1530         }
1531         else
1532         {
1533             outputTriplet(out, visit, "(", " != ", ")");
1534         }
1535     }
1536     else
1537     {
1538         if (visit == PreVisit && op == EOpNotEqual)
1539         {
1540             out << "!";
1541         }
1542 
1543         if (type.isArray())
1544         {
1545             const TString &functionName = addArrayEqualityFunction(type);
1546             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1547         }
1548         else if (type.getBasicType() == EbtStruct)
1549         {
1550             const TStructure &structure = *type.getStruct();
1551             const TString &functionName = addStructEqualityFunction(structure);
1552             outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1553         }
1554         else
1555         {
1556             ASSERT(type.isMatrix() || type.isVector());
1557             outputTriplet(out, visit, "all(", " == ", ")");
1558         }
1559     }
1560 }
1561 
outputAssign(Visit visit,const TType & type,TInfoSinkBase & out)1562 void OutputHLSL::outputAssign(Visit visit, const TType &type, TInfoSinkBase &out)
1563 {
1564     if (type.isArray())
1565     {
1566         const TString &functionName = addArrayAssignmentFunction(type);
1567         outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
1568     }
1569     else
1570     {
1571         outputTriplet(out, visit, "(", " = ", ")");
1572     }
1573 }
1574 
ancestorEvaluatesToSamplerInStruct()1575 bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
1576 {
1577     for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
1578     {
1579         TIntermNode *ancestor               = getAncestorNode(n);
1580         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
1581         if (ancestorBinary == nullptr)
1582         {
1583             return false;
1584         }
1585         switch (ancestorBinary->getOp())
1586         {
1587             case EOpIndexDirectStruct:
1588             {
1589                 const TStructure *structure = ancestorBinary->getLeft()->getType().getStruct();
1590                 const TIntermConstantUnion *index =
1591                     ancestorBinary->getRight()->getAsConstantUnion();
1592                 const TField *field = structure->fields()[index->getIConst(0)];
1593                 if (IsSampler(field->type()->getBasicType()))
1594                 {
1595                     return true;
1596                 }
1597                 break;
1598             }
1599             case EOpIndexDirect:
1600                 break;
1601             default:
1602                 // Returning a sampler from indirect indexing is not supported.
1603                 return false;
1604         }
1605     }
1606     return false;
1607 }
1608 
visitSwizzle(Visit visit,TIntermSwizzle * node)1609 bool OutputHLSL::visitSwizzle(Visit visit, TIntermSwizzle *node)
1610 {
1611     TInfoSinkBase &out = getInfoSink();
1612     if (visit == PostVisit)
1613     {
1614         out << ".";
1615         node->writeOffsetsAsXYZW(&out);
1616     }
1617     return true;
1618 }
1619 
visitBinary(Visit visit,TIntermBinary * node)1620 bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
1621 {
1622     TInfoSinkBase &out = getInfoSink();
1623 
1624     switch (node->getOp())
1625     {
1626         case EOpComma:
1627             outputTriplet(out, visit, "(", ", ", ")");
1628             break;
1629         case EOpAssign:
1630             if (node->isArray())
1631             {
1632                 TIntermAggregate *rightAgg = node->getRight()->getAsAggregate();
1633                 if (rightAgg != nullptr && rightAgg->isConstructor())
1634                 {
1635                     const TString &functionName = addArrayConstructIntoFunction(node->getType());
1636                     out << functionName << "(";
1637                     node->getLeft()->traverse(this);
1638                     TIntermSequence *seq = rightAgg->getSequence();
1639                     for (auto &arrayElement : *seq)
1640                     {
1641                         out << ", ";
1642                         arrayElement->traverse(this);
1643                     }
1644                     out << ")";
1645                     return false;
1646                 }
1647                 // ArrayReturnValueToOutParameter should have eliminated expressions where a
1648                 // function call is assigned.
1649                 ASSERT(rightAgg == nullptr);
1650             }
1651             // Assignment expressions with atomic functions should be transformed into atomic
1652             // function calls in HLSL.
1653             // e.g. original_value = atomicAdd(dest, value) should be translated into
1654             //      InterlockedAdd(dest, value, original_value);
1655             else if (IsAtomicFunctionForSharedVariableDirectAssign(*node))
1656             {
1657                 TIntermAggregate *atomicFunctionNode = node->getRight()->getAsAggregate();
1658                 TOperator atomicFunctionOp           = atomicFunctionNode->getOp();
1659                 out << GetHLSLAtomicFunctionStringAndLeftParenthesis(atomicFunctionOp);
1660                 TIntermSequence *argumentSeq = atomicFunctionNode->getSequence();
1661                 ASSERT(argumentSeq->size() >= 2u);
1662                 for (auto &argument : *argumentSeq)
1663                 {
1664                     argument->traverse(this);
1665                     out << ", ";
1666                 }
1667                 node->getLeft()->traverse(this);
1668                 out << ")";
1669                 return false;
1670             }
1671             else if (IsInShaderStorageBlock(node->getLeft()))
1672             {
1673                 mSSBOOutputHLSL->outputStoreFunctionCallPrefix(node->getLeft());
1674                 out << ", ";
1675                 if (IsInShaderStorageBlock(node->getRight()))
1676                 {
1677                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1678                 }
1679                 else
1680                 {
1681                     node->getRight()->traverse(this);
1682                 }
1683 
1684                 out << ")";
1685                 return false;
1686             }
1687             else if (IsInShaderStorageBlock(node->getRight()))
1688             {
1689                 node->getLeft()->traverse(this);
1690                 out << " = ";
1691                 mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1692                 return false;
1693             }
1694 
1695             outputAssign(visit, node->getType(), out);
1696             break;
1697         case EOpInitialize:
1698             if (visit == PreVisit)
1699             {
1700                 TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
1701                 ASSERT(symbolNode);
1702                 TIntermTyped *initializer = node->getRight();
1703 
1704                 // Global initializers must be constant at this point.
1705                 ASSERT(symbolNode->getQualifier() != EvqGlobal || initializer->hasConstantValue());
1706 
1707                 // GLSL allows to write things like "float x = x;" where a new variable x is defined
1708                 // and the value of an existing variable x is assigned. HLSL uses C semantics (the
1709                 // new variable is created before the assignment is evaluated), so we need to
1710                 // convert
1711                 // this to "float t = x, x = t;".
1712                 if (writeSameSymbolInitializer(out, symbolNode, initializer))
1713                 {
1714                     // Skip initializing the rest of the expression
1715                     return false;
1716                 }
1717                 else if (writeConstantInitialization(out, symbolNode, initializer))
1718                 {
1719                     return false;
1720                 }
1721             }
1722             else if (visit == InVisit)
1723             {
1724                 out << " = ";
1725                 if (IsInShaderStorageBlock(node->getRight()))
1726                 {
1727                     mSSBOOutputHLSL->outputLoadFunctionCall(node->getRight());
1728                     return false;
1729                 }
1730             }
1731             break;
1732         case EOpAddAssign:
1733             outputTriplet(out, visit, "(", " += ", ")");
1734             break;
1735         case EOpSubAssign:
1736             outputTriplet(out, visit, "(", " -= ", ")");
1737             break;
1738         case EOpMulAssign:
1739             outputTriplet(out, visit, "(", " *= ", ")");
1740             break;
1741         case EOpVectorTimesScalarAssign:
1742             outputTriplet(out, visit, "(", " *= ", ")");
1743             break;
1744         case EOpMatrixTimesScalarAssign:
1745             outputTriplet(out, visit, "(", " *= ", ")");
1746             break;
1747         case EOpVectorTimesMatrixAssign:
1748             if (visit == PreVisit)
1749             {
1750                 out << "(";
1751             }
1752             else if (visit == InVisit)
1753             {
1754                 out << " = mul(";
1755                 node->getLeft()->traverse(this);
1756                 out << ", transpose(";
1757             }
1758             else
1759             {
1760                 out << ")))";
1761             }
1762             break;
1763         case EOpMatrixTimesMatrixAssign:
1764             if (visit == PreVisit)
1765             {
1766                 out << "(";
1767             }
1768             else if (visit == InVisit)
1769             {
1770                 out << " = transpose(mul(transpose(";
1771                 node->getLeft()->traverse(this);
1772                 out << "), transpose(";
1773             }
1774             else
1775             {
1776                 out << "))))";
1777             }
1778             break;
1779         case EOpDivAssign:
1780             outputTriplet(out, visit, "(", " /= ", ")");
1781             break;
1782         case EOpIModAssign:
1783             outputTriplet(out, visit, "(", " %= ", ")");
1784             break;
1785         case EOpBitShiftLeftAssign:
1786             outputTriplet(out, visit, "(", " <<= ", ")");
1787             break;
1788         case EOpBitShiftRightAssign:
1789             outputTriplet(out, visit, "(", " >>= ", ")");
1790             break;
1791         case EOpBitwiseAndAssign:
1792             outputTriplet(out, visit, "(", " &= ", ")");
1793             break;
1794         case EOpBitwiseXorAssign:
1795             outputTriplet(out, visit, "(", " ^= ", ")");
1796             break;
1797         case EOpBitwiseOrAssign:
1798             outputTriplet(out, visit, "(", " |= ", ")");
1799             break;
1800         case EOpIndexDirect:
1801         {
1802             const TType &leftType = node->getLeft()->getType();
1803             if (leftType.isInterfaceBlock())
1804             {
1805                 if (visit == PreVisit)
1806                 {
1807                     TIntermSymbol *instanceArraySymbol    = node->getLeft()->getAsSymbolNode();
1808                     const TInterfaceBlock *interfaceBlock = leftType.getInterfaceBlock();
1809 
1810                     ASSERT(leftType.getQualifier() == EvqUniform);
1811                     if (mReferencedUniformBlocks.count(interfaceBlock->uniqueId().get()) == 0)
1812                     {
1813                         mReferencedUniformBlocks[interfaceBlock->uniqueId().get()] =
1814                             new TReferencedBlock(interfaceBlock, &instanceArraySymbol->variable());
1815                     }
1816                     const int arrayIndex = node->getRight()->getAsConstantUnion()->getIConst(0);
1817                     out << mResourcesHLSL->InterfaceBlockInstanceString(
1818                         instanceArraySymbol->getName(), arrayIndex);
1819                     return false;
1820                 }
1821             }
1822             else if (ancestorEvaluatesToSamplerInStruct())
1823             {
1824                 // All parts of an expression that access a sampler in a struct need to use _ as
1825                 // separator to access the sampler variable that has been moved out of the struct.
1826                 outputTriplet(out, visit, "", "_", "");
1827             }
1828             else if (IsAtomicCounter(leftType.getBasicType()))
1829             {
1830                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1831             }
1832             else
1833             {
1834                 outputTriplet(out, visit, "", "[", "]");
1835                 if (visit == PostVisit)
1836                 {
1837                     const TInterfaceBlock *interfaceBlock =
1838                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1839                     if (interfaceBlock &&
1840                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1841                     {
1842                         // If the uniform block member's type is not structure, we had explicitly
1843                         // packed the member into a structure, so need to add an operator of field
1844                         // slection.
1845                         const TField *field    = interfaceBlock->fields()[0];
1846                         const TType *fieldType = field->type();
1847                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1848                             fieldType->isScalarArray())
1849                         {
1850                             out << "." << Decorate(field->name());
1851                         }
1852                     }
1853                 }
1854             }
1855         }
1856         break;
1857         case EOpIndexIndirect:
1858         {
1859             // We do not currently support indirect references to interface blocks
1860             ASSERT(node->getLeft()->getBasicType() != EbtInterfaceBlock);
1861 
1862             const TType &leftType = node->getLeft()->getType();
1863             if (IsAtomicCounter(leftType.getBasicType()))
1864             {
1865                 outputTriplet(out, visit, "", " + (", ") * ATOMIC_COUNTER_ARRAY_STRIDE");
1866             }
1867             else
1868             {
1869                 outputTriplet(out, visit, "", "[", "]");
1870                 if (visit == PostVisit)
1871                 {
1872                     const TInterfaceBlock *interfaceBlock =
1873                         GetInterfaceBlockOfUniformBlockNearestIndexOperator(node->getLeft());
1874                     if (interfaceBlock &&
1875                         mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1876                     {
1877                         // If the uniform block member's type is not structure, we had explicitly
1878                         // packed the member into a structure, so need to add an operator of field
1879                         // slection.
1880                         const TField *field    = interfaceBlock->fields()[0];
1881                         const TType *fieldType = field->type();
1882                         if (fieldType->isMatrix() || fieldType->isVectorArray() ||
1883                             fieldType->isScalarArray())
1884                         {
1885                             out << "." << Decorate(field->name());
1886                         }
1887                     }
1888                 }
1889             }
1890             break;
1891         }
1892         case EOpIndexDirectStruct:
1893         {
1894             const TStructure *structure       = node->getLeft()->getType().getStruct();
1895             const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1896             const TField *field               = structure->fields()[index->getIConst(0)];
1897 
1898             // In cases where indexing returns a sampler, we need to access the sampler variable
1899             // that has been moved out of the struct.
1900             bool indexingReturnsSampler = IsSampler(field->type()->getBasicType());
1901             if (visit == PreVisit && indexingReturnsSampler)
1902             {
1903                 // Samplers extracted from structs have "angle" prefix to avoid name conflicts.
1904                 // This prefix is only output at the beginning of the indexing expression, which
1905                 // may have multiple parts.
1906                 out << "angle";
1907             }
1908             if (!indexingReturnsSampler)
1909             {
1910                 // All parts of an expression that access a sampler in a struct need to use _ as
1911                 // separator to access the sampler variable that has been moved out of the struct.
1912                 indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
1913             }
1914             if (visit == InVisit)
1915             {
1916                 if (indexingReturnsSampler)
1917                 {
1918                     out << "_" << field->name();
1919                 }
1920                 else
1921                 {
1922                     out << "." << DecorateField(field->name(), *structure);
1923                 }
1924 
1925                 return false;
1926             }
1927         }
1928         break;
1929         case EOpIndexDirectInterfaceBlock:
1930         {
1931             ASSERT(!IsInShaderStorageBlock(node->getLeft()));
1932             bool structInStd140UniformBlock = node->getBasicType() == EbtStruct &&
1933                                               IsInStd140UniformBlock(node->getLeft()) &&
1934                                               needStructMapping(node);
1935             if (visit == PreVisit && structInStd140UniformBlock)
1936             {
1937                 mNeedStructMapping = true;
1938                 out << "map";
1939             }
1940             if (visit == InVisit)
1941             {
1942                 const TInterfaceBlock *interfaceBlock =
1943                     node->getLeft()->getType().getInterfaceBlock();
1944                 const TIntermConstantUnion *index = node->getRight()->getAsConstantUnion();
1945                 const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
1946                 if (structInStd140UniformBlock ||
1947                     mUniformBlockOptimizedMap.count(interfaceBlock->uniqueId().get()) != 0)
1948                 {
1949                     out << "_";
1950                 }
1951                 else
1952                 {
1953                     out << ".";
1954                 }
1955                 out << Decorate(field->name());
1956 
1957                 return false;
1958             }
1959             break;
1960         }
1961         case EOpAdd:
1962             outputTriplet(out, visit, "(", " + ", ")");
1963             break;
1964         case EOpSub:
1965             outputTriplet(out, visit, "(", " - ", ")");
1966             break;
1967         case EOpMul:
1968             outputTriplet(out, visit, "(", " * ", ")");
1969             break;
1970         case EOpDiv:
1971             outputTriplet(out, visit, "(", " / ", ")");
1972             break;
1973         case EOpIMod:
1974             outputTriplet(out, visit, "(", " % ", ")");
1975             break;
1976         case EOpBitShiftLeft:
1977             outputTriplet(out, visit, "(", " << ", ")");
1978             break;
1979         case EOpBitShiftRight:
1980             outputTriplet(out, visit, "(", " >> ", ")");
1981             break;
1982         case EOpBitwiseAnd:
1983             outputTriplet(out, visit, "(", " & ", ")");
1984             break;
1985         case EOpBitwiseXor:
1986             outputTriplet(out, visit, "(", " ^ ", ")");
1987             break;
1988         case EOpBitwiseOr:
1989             outputTriplet(out, visit, "(", " | ", ")");
1990             break;
1991         case EOpEqual:
1992         case EOpNotEqual:
1993             outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
1994             break;
1995         case EOpLessThan:
1996             outputTriplet(out, visit, "(", " < ", ")");
1997             break;
1998         case EOpGreaterThan:
1999             outputTriplet(out, visit, "(", " > ", ")");
2000             break;
2001         case EOpLessThanEqual:
2002             outputTriplet(out, visit, "(", " <= ", ")");
2003             break;
2004         case EOpGreaterThanEqual:
2005             outputTriplet(out, visit, "(", " >= ", ")");
2006             break;
2007         case EOpVectorTimesScalar:
2008             outputTriplet(out, visit, "(", " * ", ")");
2009             break;
2010         case EOpMatrixTimesScalar:
2011             outputTriplet(out, visit, "(", " * ", ")");
2012             break;
2013         case EOpVectorTimesMatrix:
2014             outputTriplet(out, visit, "mul(", ", transpose(", "))");
2015             break;
2016         case EOpMatrixTimesVector:
2017             outputTriplet(out, visit, "mul(transpose(", "), ", ")");
2018             break;
2019         case EOpMatrixTimesMatrix:
2020             outputTriplet(out, visit, "transpose(mul(transpose(", "), transpose(", ")))");
2021             break;
2022         case EOpLogicalOr:
2023             // HLSL doesn't short-circuit ||, so we assume that || affected by short-circuiting have
2024             // been unfolded.
2025             ASSERT(!node->getRight()->hasSideEffects());
2026             outputTriplet(out, visit, "(", " || ", ")");
2027             return true;
2028         case EOpLogicalXor:
2029             mUsesXor = true;
2030             outputTriplet(out, visit, "xor(", ", ", ")");
2031             break;
2032         case EOpLogicalAnd:
2033             // HLSL doesn't short-circuit &&, so we assume that && affected by short-circuiting have
2034             // been unfolded.
2035             ASSERT(!node->getRight()->hasSideEffects());
2036             outputTriplet(out, visit, "(", " && ", ")");
2037             return true;
2038         default:
2039             UNREACHABLE();
2040     }
2041 
2042     return true;
2043 }
2044 
visitUnary(Visit visit,TIntermUnary * node)2045 bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
2046 {
2047     TInfoSinkBase &out = getInfoSink();
2048 
2049     switch (node->getOp())
2050     {
2051         case EOpNegative:
2052             outputTriplet(out, visit, "(-", "", ")");
2053             break;
2054         case EOpPositive:
2055             outputTriplet(out, visit, "(+", "", ")");
2056             break;
2057         case EOpLogicalNot:
2058             outputTriplet(out, visit, "(!", "", ")");
2059             break;
2060         case EOpBitwiseNot:
2061             outputTriplet(out, visit, "(~", "", ")");
2062             break;
2063         case EOpPostIncrement:
2064             outputTriplet(out, visit, "(", "", "++)");
2065             break;
2066         case EOpPostDecrement:
2067             outputTriplet(out, visit, "(", "", "--)");
2068             break;
2069         case EOpPreIncrement:
2070             outputTriplet(out, visit, "(++", "", ")");
2071             break;
2072         case EOpPreDecrement:
2073             outputTriplet(out, visit, "(--", "", ")");
2074             break;
2075         case EOpRadians:
2076             outputTriplet(out, visit, "radians(", "", ")");
2077             break;
2078         case EOpDegrees:
2079             outputTriplet(out, visit, "degrees(", "", ")");
2080             break;
2081         case EOpSin:
2082             outputTriplet(out, visit, "sin(", "", ")");
2083             break;
2084         case EOpCos:
2085             outputTriplet(out, visit, "cos(", "", ")");
2086             break;
2087         case EOpTan:
2088             outputTriplet(out, visit, "tan(", "", ")");
2089             break;
2090         case EOpAsin:
2091             outputTriplet(out, visit, "asin(", "", ")");
2092             break;
2093         case EOpAcos:
2094             outputTriplet(out, visit, "acos(", "", ")");
2095             break;
2096         case EOpAtan:
2097             outputTriplet(out, visit, "atan(", "", ")");
2098             break;
2099         case EOpSinh:
2100             outputTriplet(out, visit, "sinh(", "", ")");
2101             break;
2102         case EOpCosh:
2103             outputTriplet(out, visit, "cosh(", "", ")");
2104             break;
2105         case EOpTanh:
2106         case EOpAsinh:
2107         case EOpAcosh:
2108         case EOpAtanh:
2109             ASSERT(node->getUseEmulatedFunction());
2110             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2111             break;
2112         case EOpExp:
2113             outputTriplet(out, visit, "exp(", "", ")");
2114             break;
2115         case EOpLog:
2116             outputTriplet(out, visit, "log(", "", ")");
2117             break;
2118         case EOpExp2:
2119             outputTriplet(out, visit, "exp2(", "", ")");
2120             break;
2121         case EOpLog2:
2122             outputTriplet(out, visit, "log2(", "", ")");
2123             break;
2124         case EOpSqrt:
2125             outputTriplet(out, visit, "sqrt(", "", ")");
2126             break;
2127         case EOpInversesqrt:
2128             outputTriplet(out, visit, "rsqrt(", "", ")");
2129             break;
2130         case EOpAbs:
2131             outputTriplet(out, visit, "abs(", "", ")");
2132             break;
2133         case EOpSign:
2134             outputTriplet(out, visit, "sign(", "", ")");
2135             break;
2136         case EOpFloor:
2137             outputTriplet(out, visit, "floor(", "", ")");
2138             break;
2139         case EOpTrunc:
2140             outputTriplet(out, visit, "trunc(", "", ")");
2141             break;
2142         case EOpRound:
2143             outputTriplet(out, visit, "round(", "", ")");
2144             break;
2145         case EOpRoundEven:
2146             ASSERT(node->getUseEmulatedFunction());
2147             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2148             break;
2149         case EOpCeil:
2150             outputTriplet(out, visit, "ceil(", "", ")");
2151             break;
2152         case EOpFract:
2153             outputTriplet(out, visit, "frac(", "", ")");
2154             break;
2155         case EOpIsnan:
2156             if (node->getUseEmulatedFunction())
2157                 writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2158             else
2159                 outputTriplet(out, visit, "isnan(", "", ")");
2160             mRequiresIEEEStrictCompiling = true;
2161             break;
2162         case EOpIsinf:
2163             outputTriplet(out, visit, "isinf(", "", ")");
2164             break;
2165         case EOpFloatBitsToInt:
2166             outputTriplet(out, visit, "asint(", "", ")");
2167             break;
2168         case EOpFloatBitsToUint:
2169             outputTriplet(out, visit, "asuint(", "", ")");
2170             break;
2171         case EOpIntBitsToFloat:
2172             outputTriplet(out, visit, "asfloat(", "", ")");
2173             break;
2174         case EOpUintBitsToFloat:
2175             outputTriplet(out, visit, "asfloat(", "", ")");
2176             break;
2177         case EOpPackSnorm2x16:
2178         case EOpPackUnorm2x16:
2179         case EOpPackHalf2x16:
2180         case EOpUnpackSnorm2x16:
2181         case EOpUnpackUnorm2x16:
2182         case EOpUnpackHalf2x16:
2183         case EOpPackUnorm4x8:
2184         case EOpPackSnorm4x8:
2185         case EOpUnpackUnorm4x8:
2186         case EOpUnpackSnorm4x8:
2187             ASSERT(node->getUseEmulatedFunction());
2188             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2189             break;
2190         case EOpLength:
2191             outputTriplet(out, visit, "length(", "", ")");
2192             break;
2193         case EOpNormalize:
2194             outputTriplet(out, visit, "normalize(", "", ")");
2195             break;
2196         case EOpTranspose:
2197             outputTriplet(out, visit, "transpose(", "", ")");
2198             break;
2199         case EOpDeterminant:
2200             outputTriplet(out, visit, "determinant(transpose(", "", "))");
2201             break;
2202         case EOpInverse:
2203             ASSERT(node->getUseEmulatedFunction());
2204             writeEmulatedFunctionTriplet(out, visit, node->getFunction());
2205             break;
2206 
2207         case EOpAny:
2208             outputTriplet(out, visit, "any(", "", ")");
2209             break;
2210         case EOpAll:
2211             outputTriplet(out, visit, "all(", "", ")");
2212             break;
2213         case EOpNotComponentWise:
2214             outputTriplet(out, visit, "(!", "", ")");
2215             break;
2216         case EOpBitfieldReverse:
2217             outputTriplet(out, visit, "reversebits(", "", ")");
2218             break;
2219         case EOpBitCount:
2220             outputTriplet(out, visit, "countbits(", "", ")");
2221             break;
2222         case EOpFindLSB:
2223             // Note that it's unclear from the HLSL docs what this returns for 0, but this is tested
2224             // in GLSLTest and results are consistent with GL.
2225             outputTriplet(out, visit, "firstbitlow(", "", ")");
2226             break;
2227         case EOpFindMSB:
2228             // Note that it's unclear from the HLSL docs what this returns for 0 or -1, but this is
2229             // tested in GLSLTest and results are consistent with GL.
2230             outputTriplet(out, visit, "firstbithigh(", "", ")");
2231             break;
2232         case EOpArrayLength:
2233         {
2234             TIntermTyped *operand = node->getOperand();
2235             ASSERT(IsInShaderStorageBlock(operand));
2236             mSSBOOutputHLSL->outputLengthFunctionCall(operand);
2237             return false;
2238         }
2239         default:
2240             UNREACHABLE();
2241     }
2242 
2243     return true;
2244 }
2245 
samplerNamePrefixFromStruct(TIntermTyped * node)2246 ImmutableString OutputHLSL::samplerNamePrefixFromStruct(TIntermTyped *node)
2247 {
2248     if (node->getAsSymbolNode())
2249     {
2250         ASSERT(node->getAsSymbolNode()->variable().symbolType() != SymbolType::Empty);
2251         return node->getAsSymbolNode()->getName();
2252     }
2253     TIntermBinary *nodeBinary = node->getAsBinaryNode();
2254     switch (nodeBinary->getOp())
2255     {
2256         case EOpIndexDirect:
2257         {
2258             int index = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2259 
2260             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2261             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_" << index;
2262             return ImmutableString(prefixSink.str());
2263         }
2264         case EOpIndexDirectStruct:
2265         {
2266             const TStructure *s = nodeBinary->getLeft()->getAsTyped()->getType().getStruct();
2267             int index           = nodeBinary->getRight()->getAsConstantUnion()->getIConst(0);
2268             const TField *field = s->fields()[index];
2269 
2270             std::stringstream prefixSink = sh::InitializeStream<std::stringstream>();
2271             prefixSink << samplerNamePrefixFromStruct(nodeBinary->getLeft()) << "_"
2272                        << field->name();
2273             return ImmutableString(prefixSink.str());
2274         }
2275         default:
2276             UNREACHABLE();
2277             return kEmptyImmutableString;
2278     }
2279 }
2280 
visitBlock(Visit visit,TIntermBlock * node)2281 bool OutputHLSL::visitBlock(Visit visit, TIntermBlock *node)
2282 {
2283     TInfoSinkBase &out = getInfoSink();
2284 
2285     bool isMainBlock = mInsideMain && getParentNode()->getAsFunctionDefinition();
2286 
2287     if (mInsideFunction)
2288     {
2289         outputLineDirective(out, node->getLine().first_line);
2290         out << "{\n";
2291         if (isMainBlock)
2292         {
2293             if (mShaderType == GL_COMPUTE_SHADER)
2294             {
2295                 out << "initGLBuiltins(input);\n";
2296             }
2297             else
2298             {
2299                 out << "@@ MAIN PROLOGUE @@\n";
2300             }
2301         }
2302     }
2303 
2304     for (TIntermNode *statement : *node->getSequence())
2305     {
2306         outputLineDirective(out, statement->getLine().first_line);
2307 
2308         statement->traverse(this);
2309 
2310         // Don't output ; after case labels, they're terminated by :
2311         // This is needed especially since outputting a ; after a case statement would turn empty
2312         // case statements into non-empty case statements, disallowing fall-through from them.
2313         // Also the output code is clearer if we don't output ; after statements where it is not
2314         // needed:
2315         //  * if statements
2316         //  * switch statements
2317         //  * blocks
2318         //  * function definitions
2319         //  * loops (do-while loops output the semicolon in VisitLoop)
2320         //  * declarations that don't generate output.
2321         if (statement->getAsCaseNode() == nullptr && statement->getAsIfElseNode() == nullptr &&
2322             statement->getAsBlock() == nullptr && statement->getAsLoopNode() == nullptr &&
2323             statement->getAsSwitchNode() == nullptr &&
2324             statement->getAsFunctionDefinition() == nullptr &&
2325             (statement->getAsDeclarationNode() == nullptr ||
2326              IsDeclarationWrittenOut(statement->getAsDeclarationNode())) &&
2327             statement->getAsGlobalQualifierDeclarationNode() == nullptr)
2328         {
2329             out << ";\n";
2330         }
2331     }
2332 
2333     if (mInsideFunction)
2334     {
2335         outputLineDirective(out, node->getLine().last_line);
2336         if (isMainBlock && shaderNeedsGenerateOutput())
2337         {
2338             // We could have an empty main, a main function without a branch at the end, or a main
2339             // function with a discard statement at the end. In these cases we need to add a return
2340             // statement.
2341             bool needReturnStatement =
2342                 node->getSequence()->empty() || !node->getSequence()->back()->getAsBranchNode() ||
2343                 node->getSequence()->back()->getAsBranchNode()->getFlowOp() != EOpReturn;
2344             if (needReturnStatement)
2345             {
2346                 out << "return " << generateOutputCall() << ";\n";
2347             }
2348         }
2349         out << "}\n";
2350     }
2351 
2352     return false;
2353 }
2354 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)2355 bool OutputHLSL::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
2356 {
2357     TInfoSinkBase &out = getInfoSink();
2358 
2359     ASSERT(mCurrentFunctionMetadata == nullptr);
2360 
2361     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2362     ASSERT(index != CallDAG::InvalidIndex);
2363     mCurrentFunctionMetadata = &mASTMetadataList[index];
2364 
2365     const TFunction *func = node->getFunction();
2366 
2367     if (func->isMain())
2368     {
2369         // The stub strings below are replaced when shader is dynamically defined by its layout:
2370         switch (mShaderType)
2371         {
2372             case GL_VERTEX_SHADER:
2373                 out << "@@ VERTEX ATTRIBUTES @@\n\n"
2374                     << "@@ VERTEX OUTPUT @@\n\n"
2375                     << "VS_OUTPUT main(VS_INPUT input)";
2376                 break;
2377             case GL_FRAGMENT_SHADER:
2378                 out << "@@ PIXEL OUTPUT @@\n\n";
2379                 if (mIsEarlyFragmentTestsSpecified)
2380                 {
2381                     out << "[earlydepthstencil]\n";
2382                 }
2383                 out << "PS_OUTPUT main(@@ PIXEL MAIN PARAMETERS @@)";
2384                 break;
2385             case GL_COMPUTE_SHADER:
2386                 out << "[numthreads(" << mWorkGroupSize[0] << ", " << mWorkGroupSize[1] << ", "
2387                     << mWorkGroupSize[2] << ")]\n";
2388                 out << "void main(CS_INPUT input)";
2389                 break;
2390             default:
2391                 UNREACHABLE();
2392                 break;
2393         }
2394     }
2395     else
2396     {
2397         out << TypeString(node->getFunctionPrototype()->getType()) << " ";
2398         out << DecorateFunctionIfNeeded(func) << DisambiguateFunctionName(func)
2399             << (mOutputLod0Function ? "Lod0(" : "(");
2400 
2401         size_t paramCount = func->getParamCount();
2402         for (unsigned int i = 0; i < paramCount; i++)
2403         {
2404             const TVariable *param = func->getParam(i);
2405             ensureStructDefined(param->getType());
2406 
2407             writeParameter(param, out);
2408 
2409             if (i < paramCount - 1)
2410             {
2411                 out << ", ";
2412             }
2413         }
2414 
2415         out << ")\n";
2416     }
2417 
2418     mInsideFunction = true;
2419     if (func->isMain())
2420     {
2421         mInsideMain = true;
2422     }
2423     // The function body node will output braces.
2424     node->getBody()->traverse(this);
2425     mInsideFunction = false;
2426     mInsideMain     = false;
2427 
2428     mCurrentFunctionMetadata = nullptr;
2429 
2430     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2431     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2432     {
2433         ASSERT(!node->getFunction()->isMain());
2434         mOutputLod0Function = true;
2435         node->traverse(this);
2436         mOutputLod0Function = false;
2437     }
2438 
2439     return false;
2440 }
2441 
visitDeclaration(Visit visit,TIntermDeclaration * node)2442 bool OutputHLSL::visitDeclaration(Visit visit, TIntermDeclaration *node)
2443 {
2444     if (visit == PreVisit)
2445     {
2446         TIntermSequence *sequence = node->getSequence();
2447         TIntermTyped *declarator  = (*sequence)[0]->getAsTyped();
2448         ASSERT(sequence->size() == 1);
2449         ASSERT(declarator);
2450 
2451         if (IsDeclarationWrittenOut(node))
2452         {
2453             TInfoSinkBase &out = getInfoSink();
2454             ensureStructDefined(declarator->getType());
2455 
2456             if (!declarator->getAsSymbolNode() ||
2457                 declarator->getAsSymbolNode()->variable().symbolType() !=
2458                     SymbolType::Empty)  // Variable declaration
2459             {
2460                 if (declarator->getQualifier() == EvqShared)
2461                 {
2462                     out << "groupshared ";
2463                 }
2464                 else if (!mInsideFunction)
2465                 {
2466                     out << "static ";
2467                 }
2468 
2469                 out << TypeString(declarator->getType()) + " ";
2470 
2471                 TIntermSymbol *symbol = declarator->getAsSymbolNode();
2472 
2473                 if (symbol)
2474                 {
2475                     symbol->traverse(this);
2476                     out << ArrayString(symbol->getType());
2477                     // Temporarily disable shadred memory initialization. It is very slow for D3D11
2478                     // drivers to compile a compute shader if we add code to initialize a
2479                     // groupshared array variable with a large array size. And maybe produce
2480                     // incorrect result. See http://anglebug.com/40644676.
2481                     if (declarator->getQualifier() != EvqShared)
2482                     {
2483                         out << " = " + zeroInitializer(symbol->getType());
2484                     }
2485                 }
2486                 else
2487                 {
2488                     declarator->traverse(this);
2489                 }
2490             }
2491         }
2492         else if (IsVaryingOut(declarator->getQualifier()))
2493         {
2494             TIntermSymbol *symbol = declarator->getAsSymbolNode();
2495             ASSERT(symbol);  // Varying declarations can't have initializers.
2496 
2497             const TVariable &variable = symbol->variable();
2498 
2499             if (variable.symbolType() != SymbolType::Empty)
2500             {
2501                 // Vertex outputs which are declared but not written to should still be declared to
2502                 // allow successful linking.
2503                 mReferencedVaryings[symbol->uniqueId().get()] = &variable;
2504             }
2505         }
2506     }
2507     return false;
2508 }
2509 
// Global qualifier declarations produce no HLSL output: both arguments are ignored and the
// function only returns false.
bool OutputHLSL::visitGlobalQualifierDeclaration(Visit visit,
                                                 TIntermGlobalQualifierDeclaration *node)
{
    // Do not do any translation
    return false;
}
2516 
visitFunctionPrototype(TIntermFunctionPrototype * node)2517 void OutputHLSL::visitFunctionPrototype(TIntermFunctionPrototype *node)
2518 {
2519     TInfoSinkBase &out = getInfoSink();
2520 
2521     size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
2522     // Skip the prototype if it is not implemented (and thus not used)
2523     if (index == CallDAG::InvalidIndex)
2524     {
2525         return;
2526     }
2527 
2528     const TFunction *func = node->getFunction();
2529 
2530     TString name = DecorateFunctionIfNeeded(func);
2531     out << TypeString(node->getType()) << " " << name << DisambiguateFunctionName(func)
2532         << (mOutputLod0Function ? "Lod0(" : "(");
2533 
2534     size_t paramCount = func->getParamCount();
2535     for (unsigned int i = 0; i < paramCount; i++)
2536     {
2537         writeParameter(func->getParam(i), out);
2538 
2539         if (i < paramCount - 1)
2540         {
2541             out << ", ";
2542         }
2543     }
2544 
2545     out << ");\n";
2546 
2547     // Also prototype the Lod0 variant if needed
2548     bool needsLod0 = mASTMetadataList[index].mNeedsLod0;
2549     if (needsLod0 && !mOutputLod0Function && mShaderType == GL_FRAGMENT_SHADER)
2550     {
2551         mOutputLod0Function = true;
2552         node->traverse(this);
2553         mOutputLod0Function = false;
2554     }
2555 }
2556 
// Writes the HLSL form of an aggregate node, selected by the node's operator.
//
// Two shapes of handling appear below:
//  - Function-call-like operators (the first case group, which also holds the default label) and
//    SSBO atomic operations write the complete call, traversing the arguments themselves, and
//    return false.
//  - The remaining operators hand fixed strings to outputTriplet(), writeEmulatedFunctionTriplet()
//    or outputConstructor() together with the current |visit| phase, then fall out of the switch
//    and return true.
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
    TInfoSinkBase &out = getInfoSink();

    switch (node->getOp())
    {
        // Note: the default label shares this body, so any operator without its own case below is
        // handled here as well.
        case EOpCallFunctionInAST:
        case EOpCallInternalRawFunction:
        default:
        {
            TIntermSequence *arguments = node->getSequence();

            // Candidate for the Lod0 variant: only inside a discontinuous loop or while a Lod0
            // function is being output, and only in fragment shaders.
            bool lod0 = (mInsideDiscontinuousLoop || mOutputLod0Function) &&
                        mShaderType == GL_FRAGMENT_SHADER;

            // No raw function is expected.
            ASSERT(node->getOp() != EOpCallInternalRawFunction);

            if (node->getOp() == EOpCallFunctionInAST)
            {
                // Call to a function defined in the shader.
                if (node->isArray())
                {
                    UNIMPLEMENTED();
                }
                size_t index = mCallDag.findIndex(node->getFunction()->uniqueId());
                ASSERT(index != CallDAG::InvalidIndex);
                // Only call the "Lod0" variant if the callee's metadata says it needs one.
                lod0 &= mASTMetadataList[index].mNeedsLod0;

                out << DecorateFunctionIfNeeded(node->getFunction());
                out << DisambiguateFunctionName(node->getSequence());
                out << (lod0 ? "Lod0(" : "(");
            }
            else if (node->getFunction()->isImageFunction())
            {
                // Image function: the name to emit is obtained from mImageFunctionHLSL based on
                // the type of the first argument.
                const ImmutableString &name              = node->getFunction()->name();
                TType type                               = (*arguments)[0]->getAsTyped()->getType();
                const ImmutableString &imageFunctionName = mImageFunctionHLSL->useImageFunction(
                    name, type.getBasicType(), type.getLayoutQualifier().imageInternalFormat,
                    type.getMemoryQualifier().readonly);
                out << imageFunctionName << "(";
            }
            else if (node->getFunction()->isAtomicCounterFunction())
            {
                // Atomic counter function: the name to emit comes from
                // mAtomicCounterFunctionHLSL.
                const ImmutableString &name = node->getFunction()->name();
                ImmutableString atomicFunctionName =
                    mAtomicCounterFunctionHLSL->useAtomicCounterFunction(name);
                out << atomicFunctionName << "(";
            }
            else
            {
                // Everything else is handled as a texture function, keyed on the basic type of
                // the first argument and the nominal size of the second one (if any).
                const ImmutableString &name = node->getFunction()->name();
                TBasicType samplerType = (*arguments)[0]->getAsTyped()->getType().getBasicType();
                int coords = 0;  // textureSize(gsampler2DMS) doesn't have a second argument.
                if (arguments->size() > 1)
                {
                    coords = (*arguments)[1]->getAsTyped()->getNominalSize();
                }
                const ImmutableString &textureFunctionName =
                    mTextureFunctionHLSL->useTextureFunction(name, samplerType, coords,
                                                             arguments->size(), lod0, mShaderType);
                out << textureFunctionName << "(";
            }

            // Write the comma-separated argument list.
            for (TIntermSequence::iterator arg = arguments->begin(); arg != arguments->end(); arg++)
            {
                TIntermTyped *typedArg = (*arg)->getAsTyped();

                (*arg)->traverse(this);

                // A struct argument that contains samplers is followed by one extra argument per
                // sampler symbol created from the struct type.
                if (typedArg->getType().isStructureContainingSamplers())
                {
                    const TType &argType = typedArg->getType();
                    TVector<const TVariable *> samplerSymbols;
                    ImmutableString structName = samplerNamePrefixFromStruct(typedArg);
                    std::string namePrefix     = "angle_";
                    namePrefix += structName.data();
                    argType.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols,
                                                 nullptr, mSymbolTable);
                    for (const TVariable *sampler : samplerSymbols)
                    {
                        // In case of HLSL 4.1+, this symbol is the sampler index, and in case
                        // of D3D9, it's the sampler variable.
                        out << ", " << sampler->name();
                    }
                }

                // Separator after every argument except the last one.
                if (arg < arguments->end() - 1)
                {
                    out << ", ";
                }
            }

            out << ")";

            // The whole call, including its arguments, has been written above.
            return false;
        }
        case EOpConstruct:
            outputConstructor(out, visit, node);
            break;
        case EOpEqualComponentWise:
            outputTriplet(out, visit, "(", " == ", ")");
            break;
        case EOpNotEqualComponentWise:
            outputTriplet(out, visit, "(", " != ", ")");
            break;
        case EOpLessThanComponentWise:
            outputTriplet(out, visit, "(", " < ", ")");
            break;
        case EOpGreaterThanComponentWise:
            outputTriplet(out, visit, "(", " > ", ")");
            break;
        case EOpLessThanEqualComponentWise:
            outputTriplet(out, visit, "(", " <= ", ")");
            break;
        case EOpGreaterThanEqualComponentWise:
            outputTriplet(out, visit, "(", " >= ", ")");
            break;
        case EOpMod:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        case EOpModf:
            outputTriplet(out, visit, "modf(", ", ", ")");
            break;
        case EOpPow:
            outputTriplet(out, visit, "pow(", ", ", ")");
            break;
        case EOpAtan:
            ASSERT(node->getSequence()->size() == 2);  // atan(x) is a unary operator
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        case EOpMin:
            outputTriplet(out, visit, "min(", ", ", ")");
            break;
        case EOpMax:
            outputTriplet(out, visit, "max(", ", ", ")");
            break;
        case EOpClamp:
            outputTriplet(out, visit, "clamp(", ", ", ")");
            break;
        case EOpMix:
        {
            TIntermTyped *lastParamNode = (*(node->getSequence()))[2]->getAsTyped();
            if (lastParamNode->getType().getBasicType() == EbtBool)
            {
                // There is no HLSL equivalent for ESSL3 built-in "genType mix (genType x, genType
                // y, genBType a)",
                // so use emulated version.
                ASSERT(node->getUseEmulatedFunction());
                writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            }
            else
            {
                outputTriplet(out, visit, "lerp(", ", ", ")");
            }
            break;
        }
        case EOpStep:
            outputTriplet(out, visit, "step(", ", ", ")");
            break;
        case EOpSmoothstep:
            outputTriplet(out, visit, "smoothstep(", ", ", ")");
            break;
        case EOpFma:
            outputTriplet(out, visit, "mad(", ", ", ")");
            break;
        case EOpFrexp:
        case EOpLdexp:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        case EOpDistance:
            outputTriplet(out, visit, "distance(", ", ", ")");
            break;
        case EOpDot:
            outputTriplet(out, visit, "dot(", ", ", ")");
            break;
        case EOpCross:
            outputTriplet(out, visit, "cross(", ", ", ")");
            break;
        case EOpFaceforward:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        case EOpReflect:
            outputTriplet(out, visit, "reflect(", ", ", ")");
            break;
        case EOpRefract:
            outputTriplet(out, visit, "refract(", ", ", ")");
            break;
        case EOpOuterProduct:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        case EOpMatrixCompMult:
            outputTriplet(out, visit, "(", " * ", ")");
            break;
        case EOpBitfieldExtract:
        case EOpBitfieldInsert:
        case EOpUaddCarry:
        case EOpUsubBorrow:
        case EOpUmulExtended:
        case EOpImulExtended:
            ASSERT(node->getUseEmulatedFunction());
            writeEmulatedFunctionTriplet(out, visit, node->getFunction());
            break;
        // Derivative operators: inside a discontinuous loop or a Lod0 function no ddx/ddy/fwidth
        // call is emitted; the operand is instead wrapped as "(<operand>, 0.0)".
        case EOpDFdx:
            if (mInsideDiscontinuousLoop || mOutputLod0Function)
            {
                outputTriplet(out, visit, "(", "", ", 0.0)");
            }
            else
            {
                outputTriplet(out, visit, "ddx(", "", ")");
            }
            break;
        case EOpDFdy:
            if (mInsideDiscontinuousLoop || mOutputLod0Function)
            {
                outputTriplet(out, visit, "(", "", ", 0.0)");
            }
            else
            {
                outputTriplet(out, visit, "ddy(", "", ")");
            }
            break;
        case EOpFwidth:
            if (mInsideDiscontinuousLoop || mOutputLod0Function)
            {
                outputTriplet(out, visit, "(", "", ", 0.0)");
            }
            else
            {
                outputTriplet(out, visit, "fwidth(", "", ")");
            }
            break;
        // interpolateAt*: non-flat interpolants map to EvaluateAttribute* calls. For flat
        // interpolants, the centroid form writes no wrapper here, while the sample/offset forms
        // call a helper obtained from addFlatEvaluateFunction().
        case EOpInterpolateAtCentroid:
        {
            TIntermTyped *interpolantNode = (*(node->getSequence()))[0]->getAsTyped();
            if (!IsFlatInterpolant(interpolantNode))
            {
                outputTriplet(out, visit, "EvaluateAttributeCentroid(", "", ")");
            }
            break;
        }
        case EOpInterpolateAtSample:
        {
            TIntermTyped *interpolantNode = (*(node->getSequence()))[0]->getAsTyped();
            if (!IsFlatInterpolant(interpolantNode))
            {
                // The emitted clamp() references gl_NumSamples, so record that it is used.
                mUsesNumSamples = true;
                outputTriplet(out, visit, "EvaluateAttributeAtSample(", ", clamp(",
                              ", 0, gl_NumSamples - 1))");
            }
            else
            {
                const TString &functionName = addFlatEvaluateFunction(
                    interpolantNode->getType(), *StaticType::GetBasic<EbtInt, EbpUndefined, 1>());
                outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
            }
            break;
        }
        case EOpInterpolateAtOffset:
        {
            TIntermTyped *interpolantNode = (*(node->getSequence()))[0]->getAsTyped();
            if (!IsFlatInterpolant(interpolantNode))
            {
                outputTriplet(out, visit, "EvaluateAttributeSnapped(", ", int2((", ") * 16.0))");
            }
            else
            {
                const TString &functionName = addFlatEvaluateFunction(
                    interpolantNode->getType(), *StaticType::GetBasic<EbtFloat, EbpUndefined, 2>());
                outputTriplet(out, visit, (functionName + "(").c_str(), ", ", ")");
            }
            break;
        }
        case EOpBarrier:
            // barrier() is translated to GroupMemoryBarrierWithGroupSync(), which is the
            // cheapest *WithGroupSync() function, without any functionality loss, but
            // with the potential for severe performance loss.
            outputTriplet(out, visit, "GroupMemoryBarrierWithGroupSync(", "", ")");
            break;
        case EOpMemoryBarrierShared:
            outputTriplet(out, visit, "GroupMemoryBarrier(", "", ")");
            break;
        case EOpMemoryBarrierAtomicCounter:
        case EOpMemoryBarrierBuffer:
        case EOpMemoryBarrierImage:
            outputTriplet(out, visit, "DeviceMemoryBarrier(", "", ")");
            break;
        case EOpGroupMemoryBarrier:
        case EOpMemoryBarrier:
            outputTriplet(out, visit, "AllMemoryBarrier(", "", ")");
            break;

        // Single atomic function calls without return value.
        // e.g. atomicAdd(dest, value) should be translated into InterlockedAdd(dest, value).
        case EOpAtomicAdd:
        case EOpAtomicMin:
        case EOpAtomicMax:
        case EOpAtomicAnd:
        case EOpAtomicOr:
        case EOpAtomicXor:
        // The parameter 'original_value' of InterlockedExchange(dest, value, original_value)
        // and InterlockedCompareExchange(dest, compare_value, value, original_value) is not
        // optional.
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedexchange
        // https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/interlockedcompareexchange
        // So all the call of atomicExchange(dest, value) and atomicCompSwap(dest,
        // compare_value, value) should all be modified into the form of "int temp; temp =
        // atomicExchange(dest, value);" and "int temp; temp = atomicCompSwap(dest,
        // compare_value, value);" in the intermediate tree before traversing outputHLSL.
        case EOpAtomicExchange:
        case EOpAtomicCompSwap:
        {
            ASSERT(node->getChildCount() > 1);
            // The first argument is the memory operand; it decides between the SSBO path and the
            // shared-variable path.
            TIntermTyped *memNode = (*node->getSequence())[0]->getAsTyped();
            if (IsInShaderStorageBlock(memNode))
            {
                // Atomic memory functions for SSBO.
                // "_ssbo_atomicXXX_TYPE(RWByteAddressBuffer buffer, uint loc" is written to |out|.
                mSSBOOutputHLSL->outputAtomicMemoryFunctionCallPrefix(memNode, node->getOp());
                // Write the rest argument list to |out|.
                for (size_t i = 1; i < node->getChildCount(); i++)
                {
                    out << ", ";
                    TIntermTyped *argument = (*node->getSequence())[i]->getAsTyped();
                    // An argument that itself lives in a shader storage block is written through
                    // the SSBO load-function path instead of a plain traversal.
                    if (IsInShaderStorageBlock(argument))
                    {
                        mSSBOOutputHLSL->outputLoadFunctionCall(argument);
                    }
                    else
                    {
                        argument->traverse(this);
                    }
                }

                out << ")";
                return false;
            }
            else
            {
                // Atomic memory functions for shared variable.
                if (node->getOp() != EOpAtomicExchange && node->getOp() != EOpAtomicCompSwap)
                {
                    outputTriplet(out, visit,
                                  GetHLSLAtomicFunctionStringAndLeftParenthesis(node->getOp()), ",",
                                  ")");
                }
                else
                {
                    // Exchange/CompSwap on non-SSBO memory is not expected to reach this point
                    // (see the comment above the case labels).
                    UNREACHABLE();
                }
            }

            break;
        }
    }

    return true;
}
2920 
writeIfElse(TInfoSinkBase & out,TIntermIfElse * node)2921 void OutputHLSL::writeIfElse(TInfoSinkBase &out, TIntermIfElse *node)
2922 {
2923     out << "if (";
2924 
2925     node->getCondition()->traverse(this);
2926 
2927     out << ")\n";
2928 
2929     outputLineDirective(out, node->getLine().first_line);
2930 
2931     bool discard = false;
2932 
2933     if (node->getTrueBlock())
2934     {
2935         // The trueBlock child node will output braces.
2936         node->getTrueBlock()->traverse(this);
2937 
2938         // Detect true discard
2939         discard = (discard || FindDiscard::search(node->getTrueBlock()));
2940     }
2941     else
2942     {
2943         // TODO(oetuaho): Check if the semicolon inside is necessary.
2944         // It's there as a result of conservative refactoring of the output.
2945         out << "{;}\n";
2946     }
2947 
2948     outputLineDirective(out, node->getLine().first_line);
2949 
2950     if (node->getFalseBlock())
2951     {
2952         out << "else\n";
2953 
2954         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2955 
2956         // The falseBlock child node will output braces.
2957         node->getFalseBlock()->traverse(this);
2958 
2959         outputLineDirective(out, node->getFalseBlock()->getLine().first_line);
2960 
2961         // Detect false discard
2962         discard = (discard || FindDiscard::search(node->getFalseBlock()));
2963     }
2964 
2965     // ANGLE issue 486: Detect problematic conditional discard
2966     if (discard)
2967     {
2968         mUsesDiscardRewriting = true;
2969     }
2970 }
2971 
// Ternary nodes are not expected to reach HLSL output; this visitor only flags that it was
// reached and returns false.
bool OutputHLSL::visitTernary(Visit, TIntermTernary *)
{
    // Ternary ops should have been already converted to something else in the AST. HLSL ternary
    // operator doesn't short-circuit, so it's not the same as the GLSL ternary operator.
    UNREACHABLE();
    return false;
}
2979 
visitIfElse(Visit visit,TIntermIfElse * node)2980 bool OutputHLSL::visitIfElse(Visit visit, TIntermIfElse *node)
2981 {
2982     TInfoSinkBase &out = getInfoSink();
2983 
2984     ASSERT(mInsideFunction);
2985 
2986     // D3D errors when there is a gradient operation in a loop in an unflattened if.
2987     if (mShaderType == GL_FRAGMENT_SHADER && mCurrentFunctionMetadata->hasGradientLoop(node))
2988     {
2989         out << "FLATTEN ";
2990     }
2991 
2992     writeIfElse(out, node);
2993 
2994     return false;
2995 }
2996 
visitSwitch(Visit visit,TIntermSwitch * node)2997 bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
2998 {
2999     TInfoSinkBase &out = getInfoSink();
3000 
3001     ASSERT(node->getStatementList());
3002     if (visit == PreVisit)
3003     {
3004         node->setStatementList(RemoveSwitchFallThrough(node->getStatementList(), mPerfDiagnostics));
3005     }
3006     outputTriplet(out, visit, "switch (", ") ", "");
3007     // The curly braces get written when visiting the statementList block.
3008     return true;
3009 }
3010 
visitCase(Visit visit,TIntermCase * node)3011 bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
3012 {
3013     TInfoSinkBase &out = getInfoSink();
3014 
3015     if (node->hasCondition())
3016     {
3017         outputTriplet(out, visit, "case (", "", "):\n");
3018         return true;
3019     }
3020     else
3021     {
3022         out << "default:\n";
3023         return false;
3024     }
3025 }
3026 
visitConstantUnion(TIntermConstantUnion * node)3027 void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
3028 {
3029     TInfoSinkBase &out = getInfoSink();
3030     writeConstantUnion(out, node->getType(), node->getConstantValue());
3031 }
3032 
visitLoop(Visit visit,TIntermLoop * node)3033 bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
3034 {
3035     mNestedLoopDepth++;
3036 
3037     bool wasDiscontinuous = mInsideDiscontinuousLoop;
3038     mInsideDiscontinuousLoop =
3039         mInsideDiscontinuousLoop || mCurrentFunctionMetadata->mDiscontinuousLoops.count(node) > 0;
3040 
3041     TInfoSinkBase &out = getInfoSink();
3042 
3043     if (mOutputType == SH_HLSL_3_0_OUTPUT)
3044     {
3045         if (handleExcessiveLoop(out, node))
3046         {
3047             mInsideDiscontinuousLoop = wasDiscontinuous;
3048             mNestedLoopDepth--;
3049 
3050             return false;
3051         }
3052     }
3053 
3054     const char *unroll = mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3055     if (node->getType() == ELoopDoWhile)
3056     {
3057         out << "{" << unroll << " do\n";
3058 
3059         outputLineDirective(out, node->getLine().first_line);
3060     }
3061     else
3062     {
3063         out << "{" << unroll << " for(";
3064 
3065         if (node->getInit())
3066         {
3067             node->getInit()->traverse(this);
3068         }
3069 
3070         out << "; ";
3071 
3072         if (node->getCondition())
3073         {
3074             node->getCondition()->traverse(this);
3075         }
3076 
3077         out << "; ";
3078 
3079         if (node->getExpression())
3080         {
3081             node->getExpression()->traverse(this);
3082         }
3083 
3084         out << ")\n";
3085 
3086         outputLineDirective(out, node->getLine().first_line);
3087     }
3088 
3089     // The loop body node will output braces.
3090     node->getBody()->traverse(this);
3091 
3092     outputLineDirective(out, node->getLine().first_line);
3093 
3094     if (node->getType() == ELoopDoWhile)
3095     {
3096         outputLineDirective(out, node->getCondition()->getLine().first_line);
3097         out << "while (";
3098 
3099         node->getCondition()->traverse(this);
3100 
3101         out << ");\n";
3102     }
3103 
3104     out << "}\n";
3105 
3106     mInsideDiscontinuousLoop = wasDiscontinuous;
3107     mNestedLoopDepth--;
3108 
3109     return false;
3110 }
3111 
visitBranch(Visit visit,TIntermBranch * node)3112 bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
3113 {
3114     if (visit == PreVisit)
3115     {
3116         TInfoSinkBase &out = getInfoSink();
3117 
3118         switch (node->getFlowOp())
3119         {
3120             case EOpKill:
3121                 out << "discard";
3122                 break;
3123             case EOpBreak:
3124                 if (mNestedLoopDepth > 1)
3125                 {
3126                     mUsesNestedBreak = true;
3127                 }
3128 
3129                 if (mExcessiveLoopIndex)
3130                 {
3131                     out << "{Break";
3132                     mExcessiveLoopIndex->traverse(this);
3133                     out << " = true; break;}\n";
3134                 }
3135                 else
3136                 {
3137                     out << "break";
3138                 }
3139                 break;
3140             case EOpContinue:
3141                 out << "continue";
3142                 break;
3143             case EOpReturn:
3144                 if (node->getExpression())
3145                 {
3146                     ASSERT(!mInsideMain);
3147                     out << "return ";
3148                     if (IsInShaderStorageBlock(node->getExpression()))
3149                     {
3150                         mSSBOOutputHLSL->outputLoadFunctionCall(node->getExpression());
3151                         return false;
3152                     }
3153                 }
3154                 else
3155                 {
3156                     if (mInsideMain && shaderNeedsGenerateOutput())
3157                     {
3158                         out << "return " << generateOutputCall();
3159                     }
3160                     else
3161                     {
3162                         out << "return";
3163                     }
3164                 }
3165                 break;
3166             default:
3167                 UNREACHABLE();
3168         }
3169     }
3170 
3171     return true;
3172 }
3173 
3174 // Handle loops with more than 254 iterations (unsupported by D3D9) by splitting them
3175 // (The D3D documentation says 255 iterations, but the compiler complains at anything more than
3176 // 254).
handleExcessiveLoop(TInfoSinkBase & out,TIntermLoop * node)3177 bool OutputHLSL::handleExcessiveLoop(TInfoSinkBase &out, TIntermLoop *node)
3178 {
3179     const int MAX_LOOP_ITERATIONS = 254;
3180 
3181     // Parse loops of the form:
3182     // for(int index = initial; index [comparator] limit; index += increment)
3183     TIntermSymbol *index = nullptr;
3184     TOperator comparator = EOpNull;
3185     int initial          = 0;
3186     int limit            = 0;
3187     int increment        = 0;
3188 
    // Parse index name and initial value
3190     if (node->getInit())
3191     {
3192         TIntermDeclaration *init = node->getInit()->getAsDeclarationNode();
3193 
3194         if (init)
3195         {
3196             TIntermSequence *sequence = init->getSequence();
3197             TIntermTyped *variable    = (*sequence)[0]->getAsTyped();
3198 
3199             if (variable && variable->getQualifier() == EvqTemporary)
3200             {
3201                 TIntermBinary *assign = variable->getAsBinaryNode();
3202 
3203                 if (assign != nullptr && assign->getOp() == EOpInitialize)
3204                 {
3205                     TIntermSymbol *symbol          = assign->getLeft()->getAsSymbolNode();
3206                     TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
3207 
3208                     if (symbol && constant)
3209                     {
3210                         if (constant->getBasicType() == EbtInt && constant->isScalar())
3211                         {
3212                             index   = symbol;
3213                             initial = constant->getIConst(0);
3214                         }
3215                     }
3216                 }
3217             }
3218         }
3219     }
3220 
3221     // Parse comparator and limit value
3222     if (index != nullptr && node->getCondition())
3223     {
3224         TIntermBinary *test = node->getCondition()->getAsBinaryNode();
3225 
3226         if (test && test->getLeft()->getAsSymbolNode()->uniqueId() == index->uniqueId())
3227         {
3228             TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
3229 
3230             if (constant)
3231             {
3232                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3233                 {
3234                     comparator = test->getOp();
3235                     limit      = constant->getIConst(0);
3236                 }
3237             }
3238         }
3239     }
3240 
3241     // Parse increment
3242     if (index != nullptr && comparator != EOpNull && node->getExpression())
3243     {
3244         TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
3245         TIntermUnary *unaryTerminal   = node->getExpression()->getAsUnaryNode();
3246 
3247         if (binaryTerminal)
3248         {
3249             TOperator op                   = binaryTerminal->getOp();
3250             TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
3251 
3252             if (constant)
3253             {
3254                 if (constant->getBasicType() == EbtInt && constant->isScalar())
3255                 {
3256                     int value = constant->getIConst(0);
3257 
3258                     switch (op)
3259                     {
3260                         case EOpAddAssign:
3261                             increment = value;
3262                             break;
3263                         case EOpSubAssign:
3264                             increment = -value;
3265                             break;
3266                         default:
3267                             UNIMPLEMENTED();
3268                     }
3269                 }
3270             }
3271         }
3272         else if (unaryTerminal)
3273         {
3274             TOperator op = unaryTerminal->getOp();
3275 
3276             switch (op)
3277             {
3278                 case EOpPostIncrement:
3279                     increment = 1;
3280                     break;
3281                 case EOpPostDecrement:
3282                     increment = -1;
3283                     break;
3284                 case EOpPreIncrement:
3285                     increment = 1;
3286                     break;
3287                 case EOpPreDecrement:
3288                     increment = -1;
3289                     break;
3290                 default:
3291                     UNIMPLEMENTED();
3292             }
3293         }
3294     }
3295 
3296     if (index != nullptr && comparator != EOpNull && increment != 0)
3297     {
3298         if (comparator == EOpLessThanEqual)
3299         {
3300             comparator = EOpLessThan;
3301             limit += 1;
3302         }
3303 
3304         if (comparator == EOpLessThan)
3305         {
3306             int iterations = (limit - initial) / increment;
3307 
3308             if (iterations <= MAX_LOOP_ITERATIONS)
3309             {
3310                 return false;  // Not an excessive loop
3311             }
3312 
3313             TIntermSymbol *restoreIndex = mExcessiveLoopIndex;
3314             mExcessiveLoopIndex         = index;
3315 
3316             out << "{int ";
3317             index->traverse(this);
3318             out << ";\n"
3319                    "bool Break";
3320             index->traverse(this);
3321             out << " = false;\n";
3322 
3323             bool firstLoopFragment = true;
3324 
3325             while (iterations > 0)
3326             {
3327                 int clampedLimit = initial + increment * std::min(MAX_LOOP_ITERATIONS, iterations);
3328 
3329                 if (!firstLoopFragment)
3330                 {
3331                     out << "if (!Break";
3332                     index->traverse(this);
3333                     out << ") {\n";
3334                 }
3335 
3336                 if (iterations <= MAX_LOOP_ITERATIONS)  // Last loop fragment
3337                 {
3338                     mExcessiveLoopIndex = nullptr;  // Stops setting the Break flag
3339                 }
3340 
3341                 // for(int index = initial; index < clampedLimit; index += increment)
3342                 const char *unroll =
3343                     mCurrentFunctionMetadata->hasGradientInCallGraph(node) ? "LOOP" : "";
3344 
3345                 out << unroll << " for(";
3346                 index->traverse(this);
3347                 out << " = ";
3348                 out << initial;
3349 
3350                 out << "; ";
3351                 index->traverse(this);
3352                 out << " < ";
3353                 out << clampedLimit;
3354 
3355                 out << "; ";
3356                 index->traverse(this);
3357                 out << " += ";
3358                 out << increment;
3359                 out << ")\n";
3360 
3361                 outputLineDirective(out, node->getLine().first_line);
3362                 out << "{\n";
3363 
3364                 node->getBody()->traverse(this);
3365 
3366                 outputLineDirective(out, node->getLine().first_line);
3367                 out << ";}\n";
3368 
3369                 if (!firstLoopFragment)
3370                 {
3371                     out << "}\n";
3372                 }
3373 
3374                 firstLoopFragment = false;
3375 
3376                 initial += MAX_LOOP_ITERATIONS * increment;
3377                 iterations -= MAX_LOOP_ITERATIONS;
3378             }
3379 
3380             out << "}";
3381 
3382             mExcessiveLoopIndex = restoreIndex;
3383 
3384             return true;
3385         }
3386         else
3387             UNIMPLEMENTED();
3388     }
3389 
3390     return false;  // Not handled as an excessive loop
3391 }
3392 
outputTriplet(TInfoSinkBase & out,Visit visit,const char * preString,const char * inString,const char * postString)3393 void OutputHLSL::outputTriplet(TInfoSinkBase &out,
3394                                Visit visit,
3395                                const char *preString,
3396                                const char *inString,
3397                                const char *postString)
3398 {
3399     if (visit == PreVisit)
3400     {
3401         out << preString;
3402     }
3403     else if (visit == InVisit)
3404     {
3405         out << inString;
3406     }
3407     else if (visit == PostVisit)
3408     {
3409         out << postString;
3410     }
3411 }
3412 
outputLineDirective(TInfoSinkBase & out,int line)3413 void OutputHLSL::outputLineDirective(TInfoSinkBase &out, int line)
3414 {
3415     if (mCompileOptions.lineDirectives && line > 0)
3416     {
3417         out << "\n";
3418         out << "#line " << line;
3419 
3420         if (mSourcePath)
3421         {
3422             out << " \"" << mSourcePath << "\"";
3423         }
3424 
3425         out << "\n";
3426     }
3427 }
3428 
writeParameter(const TVariable * param,TInfoSinkBase & out)3429 void OutputHLSL::writeParameter(const TVariable *param, TInfoSinkBase &out)
3430 {
3431     const TType &type    = param->getType();
3432     TQualifier qualifier = type.getQualifier();
3433 
3434     TString nameStr = DecorateVariableIfNeeded(*param);
3435     ASSERT(nameStr != "");  // HLSL demands named arguments, also for prototypes
3436 
3437     if (IsSampler(type.getBasicType()))
3438     {
3439         if (mOutputType == SH_HLSL_4_1_OUTPUT)
3440         {
3441             // Samplers are passed as indices to the sampler array.
3442             ASSERT(qualifier != EvqParamOut && qualifier != EvqParamInOut);
3443             out << "const uint " << nameStr << ArrayString(type);
3444             return;
3445         }
3446     }
3447 
3448     // If the parameter is an atomic counter, we need to add an extra parameter to keep track of the
3449     // buffer offset.
3450     if (IsAtomicCounter(type.getBasicType()))
3451     {
3452         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr << ", int "
3453             << nameStr << "_offset";
3454     }
3455     else
3456     {
3457         out << QualifierString(qualifier) << " " << TypeString(type) << " " << nameStr
3458             << ArrayString(type);
3459     }
3460 
3461     // If the structure parameter contains samplers, they need to be passed into the function as
3462     // separate parameters. HLSL doesn't natively support samplers in structs.
3463     if (type.isStructureContainingSamplers())
3464     {
3465         ASSERT(qualifier != EvqParamOut && qualifier != EvqParamInOut);
3466         TVector<const TVariable *> samplerSymbols;
3467         std::string namePrefix = "angle";
3468         namePrefix += nameStr.c_str();
3469         type.createSamplerSymbols(ImmutableString(namePrefix), "", &samplerSymbols, nullptr,
3470                                   mSymbolTable);
3471         for (const TVariable *sampler : samplerSymbols)
3472         {
3473             const TType &samplerType = sampler->getType();
3474             if (mOutputType == SH_HLSL_4_1_OUTPUT)
3475             {
3476                 out << ", const uint " << sampler->name() << ArrayString(samplerType);
3477             }
3478             else
3479             {
3480                 ASSERT(IsSampler(samplerType.getBasicType()));
3481                 out << ", " << QualifierString(qualifier) << " " << TypeString(samplerType) << " "
3482                     << sampler->name() << ArrayString(samplerType);
3483             }
3484         }
3485     }
3486 }
3487 
zeroInitializer(const TType & type) const3488 TString OutputHLSL::zeroInitializer(const TType &type) const
3489 {
3490     TString string;
3491 
3492     size_t size = type.getObjectSize();
3493     if (size >= kZeroCount)
3494     {
3495         mUseZeroArray = true;
3496     }
3497     string = GetZeroInitializer(size).c_str();
3498 
3499     return "{" + string + "}";
3500 }
3501 
outputConstructor(TInfoSinkBase & out,Visit visit,TIntermAggregate * node)3502 void OutputHLSL::outputConstructor(TInfoSinkBase &out, Visit visit, TIntermAggregate *node)
3503 {
3504     // Array constructors should have been already pruned from the code.
3505     ASSERT(!node->getType().isArray());
3506 
3507     if (visit == PreVisit)
3508     {
3509         TString constructorName;
3510         if (node->getBasicType() == EbtStruct)
3511         {
3512             constructorName = mStructureHLSL->addStructConstructor(*node->getType().getStruct());
3513         }
3514         else
3515         {
3516             constructorName =
3517                 mStructureHLSL->addBuiltInConstructor(node->getType(), node->getSequence());
3518         }
3519         out << constructorName << "(";
3520     }
3521     else if (visit == InVisit)
3522     {
3523         out << ", ";
3524     }
3525     else if (visit == PostVisit)
3526     {
3527         out << ")";
3528     }
3529 }
3530 
writeConstantUnion(TInfoSinkBase & out,const TType & type,const TConstantUnion * const constUnion)3531 const TConstantUnion *OutputHLSL::writeConstantUnion(TInfoSinkBase &out,
3532                                                      const TType &type,
3533                                                      const TConstantUnion *const constUnion)
3534 {
3535     ASSERT(!type.isArray());
3536 
3537     const TConstantUnion *constUnionIterated = constUnion;
3538 
3539     const TStructure *structure = type.getStruct();
3540     if (structure)
3541     {
3542         out << mStructureHLSL->addStructConstructor(*structure) << "(";
3543 
3544         const TFieldList &fields = structure->fields();
3545 
3546         for (size_t i = 0; i < fields.size(); i++)
3547         {
3548             const TType *fieldType = fields[i]->type();
3549             constUnionIterated     = writeConstantUnion(out, *fieldType, constUnionIterated);
3550 
3551             if (i != fields.size() - 1)
3552             {
3553                 out << ", ";
3554             }
3555         }
3556 
3557         out << ")";
3558     }
3559     else
3560     {
3561         size_t size    = type.getObjectSize();
3562         bool writeType = size > 1;
3563 
3564         if (writeType)
3565         {
3566             out << TypeString(type) << "(";
3567         }
3568         constUnionIterated = writeConstantUnionArray(out, constUnionIterated, size);
3569         if (writeType)
3570         {
3571             out << ")";
3572         }
3573     }
3574 
3575     return constUnionIterated;
3576 }
3577 
writeEmulatedFunctionTriplet(TInfoSinkBase & out,Visit visit,const TFunction * function)3578 void OutputHLSL::writeEmulatedFunctionTriplet(TInfoSinkBase &out,
3579                                               Visit visit,
3580                                               const TFunction *function)
3581 {
3582     if (visit == PreVisit)
3583     {
3584         ASSERT(function != nullptr);
3585         BuiltInFunctionEmulator::WriteEmulatedFunctionName(out, function->name().data());
3586         out << "(";
3587     }
3588     else
3589     {
3590         outputTriplet(out, visit, nullptr, ", ", ")");
3591     }
3592 }
3593 
writeSameSymbolInitializer(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * expression)3594 bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out,
3595                                             TIntermSymbol *symbolNode,
3596                                             TIntermTyped *expression)
3597 {
3598     ASSERT(symbolNode->variable().symbolType() != SymbolType::Empty);
3599     const TIntermSymbol *symbolInInitializer = FindSymbolNode(expression, symbolNode->getName());
3600 
3601     if (symbolInInitializer)
3602     {
3603         // Type already printed
3604         out << "t" + str(mUniqueIndex) + " = ";
3605         expression->traverse(this);
3606         out << ", ";
3607         symbolNode->traverse(this);
3608         out << " = t" + str(mUniqueIndex);
3609 
3610         mUniqueIndex++;
3611         return true;
3612     }
3613 
3614     return false;
3615 }
3616 
writeConstantInitialization(TInfoSinkBase & out,TIntermSymbol * symbolNode,TIntermTyped * initializer)3617 bool OutputHLSL::writeConstantInitialization(TInfoSinkBase &out,
3618                                              TIntermSymbol *symbolNode,
3619                                              TIntermTyped *initializer)
3620 {
3621     if (initializer->hasConstantValue())
3622     {
3623         symbolNode->traverse(this);
3624         out << ArrayString(symbolNode->getType());
3625         out << " = {";
3626         writeConstantUnionArray(out, initializer->getConstantValue(),
3627                                 initializer->getType().getObjectSize());
3628         out << "}";
3629         return true;
3630     }
3631     return false;
3632 }
3633 
addStructEqualityFunction(const TStructure & structure)3634 TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
3635 {
3636     const TFieldList &fields = structure.fields();
3637 
3638     for (const auto &eqFunction : mStructEqualityFunctions)
3639     {
3640         if (eqFunction->structure == &structure)
3641         {
3642             return eqFunction->functionName;
3643         }
3644     }
3645 
3646     const TString &structNameString = StructNameString(structure);
3647 
3648     StructEqualityFunction *function = new StructEqualityFunction();
3649     function->structure              = &structure;
3650     function->functionName           = "angle_eq_" + structNameString;
3651 
3652     TInfoSinkBase fnOut;
3653 
3654     fnOut << "bool " << function->functionName << "(" << structNameString << " a, "
3655           << structNameString + " b)\n"
3656           << "{\n"
3657              "    return ";
3658 
3659     for (size_t i = 0; i < fields.size(); i++)
3660     {
3661         const TField *field    = fields[i];
3662         const TType *fieldType = field->type();
3663 
3664         const TString &fieldNameA = "a." + Decorate(field->name());
3665         const TString &fieldNameB = "b." + Decorate(field->name());
3666 
3667         if (i > 0)
3668         {
3669             fnOut << " && ";
3670         }
3671 
3672         fnOut << "(";
3673         outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
3674         fnOut << fieldNameA;
3675         outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
3676         fnOut << fieldNameB;
3677         outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
3678         fnOut << ")";
3679     }
3680 
3681     fnOut << ";\n" << "}\n";
3682 
3683     function->functionDefinition = fnOut.c_str();
3684 
3685     mStructEqualityFunctions.push_back(function);
3686     mEqualityFunctions.push_back(function);
3687 
3688     return function->functionName;
3689 }
3690 
addArrayEqualityFunction(const TType & type)3691 TString OutputHLSL::addArrayEqualityFunction(const TType &type)
3692 {
3693     for (const auto &eqFunction : mArrayEqualityFunctions)
3694     {
3695         if (eqFunction->type == type)
3696         {
3697             return eqFunction->functionName;
3698         }
3699     }
3700 
3701     TType elementType(type);
3702     elementType.toArrayElementType();
3703 
3704     ArrayHelperFunction *function = new ArrayHelperFunction();
3705     function->type                = type;
3706 
3707     function->functionName = ArrayHelperFunctionName("angle_eq", type);
3708 
3709     TInfoSinkBase fnOut;
3710 
3711     const TString &typeName = TypeString(type);
3712     fnOut << "bool " << function->functionName << "(" << typeName << " a" << ArrayString(type)
3713           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3714           << "{\n"
3715              "    for (int i = 0; i < "
3716           << type.getOutermostArraySize()
3717           << "; ++i)\n"
3718              "    {\n"
3719              "        if (";
3720 
3721     outputEqual(PreVisit, elementType, EOpNotEqual, fnOut);
3722     fnOut << "a[i]";
3723     outputEqual(InVisit, elementType, EOpNotEqual, fnOut);
3724     fnOut << "b[i]";
3725     outputEqual(PostVisit, elementType, EOpNotEqual, fnOut);
3726 
3727     fnOut << ") { return false; }\n"
3728              "    }\n"
3729              "    return true;\n"
3730              "}\n";
3731 
3732     function->functionDefinition = fnOut.c_str();
3733 
3734     mArrayEqualityFunctions.push_back(function);
3735     mEqualityFunctions.push_back(function);
3736 
3737     return function->functionName;
3738 }
3739 
addArrayAssignmentFunction(const TType & type)3740 TString OutputHLSL::addArrayAssignmentFunction(const TType &type)
3741 {
3742     for (const auto &assignFunction : mArrayAssignmentFunctions)
3743     {
3744         if (assignFunction.type == type)
3745         {
3746             return assignFunction.functionName;
3747         }
3748     }
3749 
3750     TType elementType(type);
3751     elementType.toArrayElementType();
3752 
3753     ArrayHelperFunction function;
3754     function.type = type;
3755 
3756     function.functionName = ArrayHelperFunctionName("angle_assign", type);
3757 
3758     TInfoSinkBase fnOut;
3759 
3760     const TString &typeName = TypeString(type);
3761     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type)
3762           << ", " << typeName << " b" << ArrayString(type) << ")\n"
3763           << "{\n"
3764              "    for (int i = 0; i < "
3765           << type.getOutermostArraySize()
3766           << "; ++i)\n"
3767              "    {\n"
3768              "        ";
3769 
3770     outputAssign(PreVisit, elementType, fnOut);
3771     fnOut << "a[i]";
3772     outputAssign(InVisit, elementType, fnOut);
3773     fnOut << "b[i]";
3774     outputAssign(PostVisit, elementType, fnOut);
3775 
3776     fnOut << ";\n"
3777              "    }\n"
3778              "}\n";
3779 
3780     function.functionDefinition = fnOut.c_str();
3781 
3782     mArrayAssignmentFunctions.push_back(function);
3783 
3784     return function.functionName;
3785 }
3786 
addArrayConstructIntoFunction(const TType & type)3787 TString OutputHLSL::addArrayConstructIntoFunction(const TType &type)
3788 {
3789     for (const auto &constructIntoFunction : mArrayConstructIntoFunctions)
3790     {
3791         if (constructIntoFunction.type == type)
3792         {
3793             return constructIntoFunction.functionName;
3794         }
3795     }
3796 
3797     TType elementType(type);
3798     elementType.toArrayElementType();
3799 
3800     ArrayHelperFunction function;
3801     function.type = type;
3802 
3803     function.functionName = ArrayHelperFunctionName("angle_construct_into", type);
3804 
3805     TInfoSinkBase fnOut;
3806 
3807     const TString &typeName = TypeString(type);
3808     fnOut << "void " << function.functionName << "(out " << typeName << " a" << ArrayString(type);
3809     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3810     {
3811         fnOut << ", " << typeName << " b" << i << ArrayString(elementType);
3812     }
3813     fnOut << ")\n"
3814              "{\n";
3815 
3816     for (unsigned int i = 0u; i < type.getOutermostArraySize(); ++i)
3817     {
3818         fnOut << "    ";
3819         outputAssign(PreVisit, elementType, fnOut);
3820         fnOut << "a[" << i << "]";
3821         outputAssign(InVisit, elementType, fnOut);
3822         fnOut << "b" << i;
3823         outputAssign(PostVisit, elementType, fnOut);
3824         fnOut << ";\n";
3825     }
3826     fnOut << "}\n";
3827 
3828     function.functionDefinition = fnOut.c_str();
3829 
3830     mArrayConstructIntoFunctions.push_back(function);
3831 
3832     return function.functionName;
3833 }
3834 
addFlatEvaluateFunction(const TType & type,const TType & parameterType)3835 TString OutputHLSL::addFlatEvaluateFunction(const TType &type, const TType &parameterType)
3836 {
3837     for (const auto &flatEvaluateFunction : mFlatEvaluateFunctions)
3838     {
3839         if (flatEvaluateFunction.type == type &&
3840             flatEvaluateFunction.parameterType == parameterType)
3841         {
3842             return flatEvaluateFunction.functionName;
3843         }
3844     }
3845 
3846     FlatEvaluateFunction function;
3847     function.type          = type;
3848     function.parameterType = parameterType;
3849 
3850     const TString &typeName          = TypeString(type);
3851     const TString &parameterTypeName = TypeString(parameterType);
3852 
3853     function.functionName = "angle_eval_flat_" + typeName + "_" + parameterTypeName;
3854 
3855     // If <interpolant> is declared with a "flat" qualifier, the interpolated
3856     // value will have the same value everywhere for a single primitive, so
3857     // the location used for the interpolation has no effect and the functions
3858     // just return that same value.
3859     TInfoSinkBase fnOut;
3860     fnOut << typeName << " " << function.functionName << "(" << typeName << " i, "
3861           << parameterTypeName << " p)\n";
3862     fnOut << "{\n" << "    return i;\n" << "}\n";
3863     function.functionDefinition = fnOut.c_str();
3864 
3865     mFlatEvaluateFunctions.push_back(function);
3866 
3867     return function.functionName;
3868 }
3869 
ensureStructDefined(const TType & type)3870 void OutputHLSL::ensureStructDefined(const TType &type)
3871 {
3872     const TStructure *structure = type.getStruct();
3873     if (structure)
3874     {
3875         ASSERT(type.getBasicType() == EbtStruct);
3876         mStructureHLSL->ensureStructDefined(*structure);
3877     }
3878 }
3879 
shaderNeedsGenerateOutput() const3880 bool OutputHLSL::shaderNeedsGenerateOutput() const
3881 {
3882     return mShaderType == GL_VERTEX_SHADER || mShaderType == GL_FRAGMENT_SHADER;
3883 }
3884 
generateOutputCall() const3885 const char *OutputHLSL::generateOutputCall() const
3886 {
3887     if (mShaderType == GL_VERTEX_SHADER)
3888     {
3889         return "generateOutput(input)";
3890     }
3891     else
3892     {
3893         return "generateOutput()";
3894     }
3895 }
3896 }  // namespace sh
3897