xref: /aosp_15_r20/external/angle/src/compiler/translator/hlsl/ShaderStorageBlockOutputHLSL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// ShaderStorageBlockOutputHLSL: A traverser that translates an ssbo_access_chain into an offset
// into a RWByteAddressBuffer.
//     // EOpIndexDirectInterfaceBlock
9 //     ssbo_variable :=
10 //       | the name of the SSBO
11 //       | the name of a variable in an SSBO backed interface block
12 
//     // EOpIndexIndirect
14 //     // EOpIndexDirect
15 //     ssbo_array_indexing := ssbo_access_chain[expr_no_ssbo]
16 
17 //     // EOpIndexDirectStruct
18 //     ssbo_structure_access := ssbo_access_chain.identifier
19 
20 //     ssbo_access_chain :=
21 //       | ssbo_variable
22 //       | ssbo_array_indexing
23 //       | ssbo_structure_access
24 //
25 
26 #include "compiler/translator/hlsl/ShaderStorageBlockOutputHLSL.h"
27 
28 #include "compiler/translator/hlsl/ResourcesHLSL.h"
29 #include "compiler/translator/hlsl/blocklayoutHLSL.h"
30 #include "compiler/translator/tree_util/IntermNode_util.h"
31 #include "compiler/translator/util.h"
32 
33 namespace sh
34 {
35 
36 namespace
37 {
38 
// Marker line that writeShaderStorageBlocksHeader() emits for non-compute shaders in place of the
// actual shader storage block declarations. NOTE(review): presumably located and substituted by
// later code, so the text is treated as fixed.
constexpr char kShaderStorageDeclarationString[] =
    "// @@ SHADER STORAGE DECLARATION STRING @@";
41 
GetBlockLayoutInfo(TIntermTyped * node,bool rowMajorAlreadyAssigned,TLayoutBlockStorage * storage,bool * rowMajor)42 void GetBlockLayoutInfo(TIntermTyped *node,
43                         bool rowMajorAlreadyAssigned,
44                         TLayoutBlockStorage *storage,
45                         bool *rowMajor)
46 {
47     TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
48     if (swizzleNode)
49     {
50         return GetBlockLayoutInfo(swizzleNode->getOperand(), rowMajorAlreadyAssigned, storage,
51                                   rowMajor);
52     }
53 
54     TIntermBinary *binaryNode = node->getAsBinaryNode();
55     if (binaryNode)
56     {
57         switch (binaryNode->getOp())
58         {
59             case EOpIndexDirectInterfaceBlock:
60             {
61                 // The column_major/row_major qualifier of a field member overrides the interface
62                 // block's row_major/column_major. So we can assign rowMajor here and don't need to
63                 // assign it again. But we still need to call recursively to get the storage's
64                 // value.
65                 const TType &type = node->getType();
66                 *rowMajor         = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
67                 return GetBlockLayoutInfo(binaryNode->getLeft(), true, storage, rowMajor);
68             }
69             case EOpIndexIndirect:
70             case EOpIndexDirect:
71             case EOpIndexDirectStruct:
72                 return GetBlockLayoutInfo(binaryNode->getLeft(), rowMajorAlreadyAssigned, storage,
73                                           rowMajor);
74             default:
75                 UNREACHABLE();
76                 return;
77         }
78     }
79 
80     const TType &type = node->getType();
81     ASSERT(type.getQualifier() == EvqBuffer);
82     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
83     ASSERT(interfaceBlock);
84     *storage = interfaceBlock->blockStorage();
85     // If the block doesn't have an instance name, rowMajorAlreadyAssigned will be false. In
86     // this situation, we still need to set rowMajor's value.
87     if (!rowMajorAlreadyAssigned)
88     {
89         *rowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
90     }
91 }
92 
93 // It's possible that the current type has lost the original layout information. So we should pass
94 // the right layout information to GetBlockMemberInfoByType.
GetBlockMemberInfoByType(const TType & type,TLayoutBlockStorage storage,bool rowMajor)95 const BlockMemberInfo GetBlockMemberInfoByType(const TType &type,
96                                                TLayoutBlockStorage storage,
97                                                bool rowMajor)
98 {
99     sh::Std140BlockEncoder std140Encoder;
100     sh::Std430BlockEncoder std430Encoder;
101     sh::HLSLBlockEncoder hlslEncoder(sh::HLSLBlockEncoder::ENCODE_PACKED, false);
102     sh::BlockLayoutEncoder *encoder = nullptr;
103 
104     if (storage == EbsStd140)
105     {
106         encoder = &std140Encoder;
107     }
108     else if (storage == EbsStd430)
109     {
110         encoder = &std430Encoder;
111     }
112     else
113     {
114         encoder = &hlslEncoder;
115     }
116 
117     std::vector<unsigned int> arraySizes;
118     const TSpan<const unsigned int> &typeArraySizes = type.getArraySizes();
119     if (!typeArraySizes.empty())
120     {
121         arraySizes.assign(typeArraySizes.begin(), typeArraySizes.end());
122     }
123     return encoder->encodeType(GLVariableType(type), arraySizes, rowMajor);
124 }
125 
GetFieldMemberInShaderStorageBlock(const TInterfaceBlock * interfaceBlock,const ImmutableString & variableName)126 const TField *GetFieldMemberInShaderStorageBlock(const TInterfaceBlock *interfaceBlock,
127                                                  const ImmutableString &variableName)
128 {
129     for (const TField *field : interfaceBlock->fields())
130     {
131         if (field->name() == variableName)
132         {
133             return field;
134         }
135     }
136     return nullptr;
137 }
138 
FindInterfaceBlock(const TInterfaceBlock * needle,const std::vector<InterfaceBlock> & haystack)139 const InterfaceBlock *FindInterfaceBlock(const TInterfaceBlock *needle,
140                                          const std::vector<InterfaceBlock> &haystack)
141 {
142     for (const InterfaceBlock &block : haystack)
143     {
144         if (strcmp(block.name.c_str(), needle->name().data()) == 0)
145         {
146             ASSERT(block.fields.size() == needle->fields().size());
147             return &block;
148         }
149     }
150 
151     UNREACHABLE();
152     return nullptr;
153 }
154 
StripArrayIndices(const std::string & nameIn)155 std::string StripArrayIndices(const std::string &nameIn)
156 {
157     std::string name = nameIn;
158     size_t pos       = name.find('[');
159     while (pos != std::string::npos)
160     {
161         size_t closePos = name.find(']', pos);
162         ASSERT(closePos != std::string::npos);
163         name.erase(pos, closePos - pos + 1);
164         pos = name.find('[', pos);
165     }
166     ASSERT(name.find(']') == std::string::npos);
167     return name;
168 }
169 
170 // Does not include any array indices.
MapVariableToField(const ShaderVariable & variable,const TField * field,std::string currentName,ShaderVarToFieldMap * shaderVarToFieldMap)171 void MapVariableToField(const ShaderVariable &variable,
172                         const TField *field,
173                         std::string currentName,
174                         ShaderVarToFieldMap *shaderVarToFieldMap)
175 {
176     ASSERT((field->type()->getStruct() == nullptr) == variable.fields.empty());
177     (*shaderVarToFieldMap)[currentName] = field;
178 
179     if (!variable.fields.empty())
180     {
181         const TStructure *subStruct = field->type()->getStruct();
182         ASSERT(variable.fields.size() == subStruct->fields().size());
183 
184         for (size_t index = 0; index < variable.fields.size(); ++index)
185         {
186             const TField *subField            = subStruct->fields()[index];
187             const ShaderVariable &subVariable = variable.fields[index];
188             std::string subName               = currentName + "." + subVariable.name;
189             MapVariableToField(subVariable, subField, subName, shaderVarToFieldMap);
190         }
191     }
192 }
193 
194 class BlockInfoVisitor final : public BlockEncoderVisitor
195 {
196   public:
BlockInfoVisitor(const std::string & prefix,TLayoutBlockStorage storage,const ShaderVarToFieldMap & shaderVarToFieldMap,BlockMemberInfoMap * blockInfoOut)197     BlockInfoVisitor(const std::string &prefix,
198                      TLayoutBlockStorage storage,
199                      const ShaderVarToFieldMap &shaderVarToFieldMap,
200                      BlockMemberInfoMap *blockInfoOut)
201         : BlockEncoderVisitor(prefix, "", getEncoder(storage)),
202           mShaderVarToFieldMap(shaderVarToFieldMap),
203           mBlockInfoOut(blockInfoOut),
204           mHLSLEncoder(HLSLBlockEncoder::ENCODE_PACKED, false),
205           mStorage(storage)
206     {}
207 
getEncoder(TLayoutBlockStorage storage)208     BlockLayoutEncoder *getEncoder(TLayoutBlockStorage storage)
209     {
210         switch (storage)
211         {
212             case EbsStd140:
213                 return &mStd140Encoder;
214             case EbsStd430:
215                 return &mStd430Encoder;
216             default:
217                 return &mHLSLEncoder;
218         }
219     }
220 
enterStructAccess(const ShaderVariable & structVar,bool isRowMajor)221     void enterStructAccess(const ShaderVariable &structVar, bool isRowMajor) override
222     {
223         BlockEncoderVisitor::enterStructAccess(structVar, isRowMajor);
224 
225         std::string variableName = StripArrayIndices(collapseNameStack());
226 
227         // Remove the trailing "."
228         variableName.pop_back();
229 
230         BlockInfoVisitor childVisitor(variableName, mStorage, mShaderVarToFieldMap, mBlockInfoOut);
231         childVisitor.getEncoder(mStorage)->enterAggregateType(structVar);
232         TraverseShaderVariables(structVar.fields, isRowMajor, &childVisitor);
233         childVisitor.getEncoder(mStorage)->exitAggregateType(structVar);
234 
235         int offset      = static_cast<int>(getEncoder(mStorage)->getCurrentOffset());
236         int arrayStride = static_cast<int>(childVisitor.getEncoder(mStorage)->getCurrentOffset());
237 
238         auto iter = mShaderVarToFieldMap.find(variableName);
239         if (iter == mShaderVarToFieldMap.end())
240             return;
241 
242         const TField *structField = iter->second;
243         if (mBlockInfoOut->count(structField) == 0)
244         {
245             mBlockInfoOut->emplace(structField, BlockMemberInfo(offset, arrayStride, -1, false));
246         }
247     }
248 
encodeVariable(const ShaderVariable & variable,const BlockMemberInfo & variableInfo,const std::string & name,const std::string & mappedName)249     void encodeVariable(const ShaderVariable &variable,
250                         const BlockMemberInfo &variableInfo,
251                         const std::string &name,
252                         const std::string &mappedName) override
253     {
254         auto iter = mShaderVarToFieldMap.find(StripArrayIndices(name));
255         if (iter == mShaderVarToFieldMap.end())
256             return;
257 
258         const TField *field = iter->second;
259         if (mBlockInfoOut->count(field) == 0)
260         {
261             mBlockInfoOut->emplace(field, variableInfo);
262         }
263     }
264 
265   private:
266     const ShaderVarToFieldMap &mShaderVarToFieldMap;
267     BlockMemberInfoMap *mBlockInfoOut;
268     Std140BlockEncoder mStd140Encoder;
269     Std430BlockEncoder mStd430Encoder;
270     HLSLBlockEncoder mHLSLEncoder;
271     TLayoutBlockStorage mStorage;
272 };
273 
GetShaderStorageBlockMembersInfo(const TInterfaceBlock * interfaceBlock,const std::vector<InterfaceBlock> & shaderStorageBlocks,BlockMemberInfoMap * blockInfoOut)274 void GetShaderStorageBlockMembersInfo(const TInterfaceBlock *interfaceBlock,
275                                       const std::vector<InterfaceBlock> &shaderStorageBlocks,
276                                       BlockMemberInfoMap *blockInfoOut)
277 {
278     // Find the sh::InterfaceBlock.
279     const InterfaceBlock *block = FindInterfaceBlock(interfaceBlock, shaderStorageBlocks);
280     ASSERT(block);
281 
282     // Map ShaderVariable to TField.
283     ShaderVarToFieldMap shaderVarToFieldMap;
284     for (size_t index = 0; index < block->fields.size(); ++index)
285     {
286         const TField *field            = interfaceBlock->fields()[index];
287         const ShaderVariable &variable = block->fields[index];
288         MapVariableToField(variable, field, variable.name, &shaderVarToFieldMap);
289     }
290 
291     BlockInfoVisitor visitor("", interfaceBlock->blockStorage(), shaderVarToFieldMap, blockInfoOut);
292     TraverseShaderVariables(block->fields, false, &visitor);
293 }
294 
Mul(TIntermTyped * left,TIntermTyped * right)295 TIntermTyped *Mul(TIntermTyped *left, TIntermTyped *right)
296 {
297     return left && right ? new TIntermBinary(EOpMul, left, right) : nullptr;
298 }
299 
Add(TIntermTyped * left,TIntermTyped * right)300 TIntermTyped *Add(TIntermTyped *left, TIntermTyped *right)
301 {
302     return left ? right ? new TIntermBinary(EOpAdd, left, right) : left : right;
303 }
304 
305 }  // anonymous namespace
306 
// Stores the OutputHLSL and ResourcesHLSL pointers (neither is freed by this class's destructor)
// and the reflected shader storage block list. Allocates the ShaderStorageBlockFunctionHLSL
// helper, which the destructor releases.
ShaderStorageBlockOutputHLSL::ShaderStorageBlockOutputHLSL(
    OutputHLSL *outputHLSL,
    ResourcesHLSL *resourcesHLSL,
    const std::vector<InterfaceBlock> &shaderStorageBlocks)
    : mOutputHLSL(outputHLSL),
      mResourcesHLSL(resourcesHLSL),
      mShaderStorageBlocks(shaderStorageBlocks)
{
    // Owned raw pointer; freed in ~ShaderStorageBlockOutputHLSL().
    mSSBOFunctionHLSL = new ShaderStorageBlockFunctionHLSL;
}
317 
~ShaderStorageBlockOutputHLSL()318 ShaderStorageBlockOutputHLSL::~ShaderStorageBlockOutputHLSL()
319 {
320     SafeDelete(mSSBOFunctionHLSL);
321 }
322 
outputStoreFunctionCallPrefix(TIntermTyped * node)323 void ShaderStorageBlockOutputHLSL::outputStoreFunctionCallPrefix(TIntermTyped *node)
324 {
325     traverseSSBOAccess(node, SSBOMethod::STORE);
326 }
327 
outputLoadFunctionCall(TIntermTyped * node)328 void ShaderStorageBlockOutputHLSL::outputLoadFunctionCall(TIntermTyped *node)
329 {
330     traverseSSBOAccess(node, SSBOMethod::LOAD);
331     mOutputHLSL->getInfoSink() << ")";
332 }
333 
outputLengthFunctionCall(TIntermTyped * node)334 void ShaderStorageBlockOutputHLSL::outputLengthFunctionCall(TIntermTyped *node)
335 {
336     traverseSSBOAccess(node, SSBOMethod::LENGTH);
337     mOutputHLSL->getInfoSink() << ")";
338 }
339 
outputAtomicMemoryFunctionCallPrefix(TIntermTyped * node,TOperator op)340 void ShaderStorageBlockOutputHLSL::outputAtomicMemoryFunctionCallPrefix(TIntermTyped *node,
341                                                                         TOperator op)
342 {
343     switch (op)
344     {
345         case EOpAtomicAdd:
346             traverseSSBOAccess(node, SSBOMethod::ATOMIC_ADD);
347             break;
348         case EOpAtomicMin:
349             traverseSSBOAccess(node, SSBOMethod::ATOMIC_MIN);
350             break;
351         case EOpAtomicMax:
352             traverseSSBOAccess(node, SSBOMethod::ATOMIC_MAX);
353             break;
354         case EOpAtomicAnd:
355             traverseSSBOAccess(node, SSBOMethod::ATOMIC_AND);
356             break;
357         case EOpAtomicOr:
358             traverseSSBOAccess(node, SSBOMethod::ATOMIC_OR);
359             break;
360         case EOpAtomicXor:
361             traverseSSBOAccess(node, SSBOMethod::ATOMIC_XOR);
362             break;
363         case EOpAtomicExchange:
364             traverseSSBOAccess(node, SSBOMethod::ATOMIC_EXCHANGE);
365             break;
366         case EOpAtomicCompSwap:
367             traverseSSBOAccess(node, SSBOMethod::ATOMIC_COMPSWAP);
368             break;
369         default:
370             UNREACHABLE();
371             break;
372     }
373 }
374 
375 // Note that we must calculate the matrix stride here instead of ShaderStorageBlockFunctionHLSL.
376 // It's because that if the current node's type is a vector which comes from a matrix, we will
377 // lose the matrix type info once we enter ShaderStorageBlockFunctionHLSL.
getMatrixStride(TIntermTyped * node,TLayoutBlockStorage storage,bool rowMajor,bool * isRowMajorMatrix) const378 int ShaderStorageBlockOutputHLSL::getMatrixStride(TIntermTyped *node,
379                                                   TLayoutBlockStorage storage,
380                                                   bool rowMajor,
381                                                   bool *isRowMajorMatrix) const
382 {
383     if (node->getType().isMatrix())
384     {
385         *isRowMajorMatrix = rowMajor;
386         return GetBlockMemberInfoByType(node->getType(), storage, rowMajor).matrixStride;
387     }
388 
389     if (node->getType().isVector())
390     {
391         TIntermBinary *binaryNode = node->getAsBinaryNode();
392         if (binaryNode)
393         {
394             return getMatrixStride(binaryNode->getLeft(), storage, rowMajor, isRowMajorMatrix);
395         }
396         else
397         {
398             TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
399             if (swizzleNode)
400             {
401                 return getMatrixStride(swizzleNode->getOperand(), storage, rowMajor,
402                                        isRowMajorMatrix);
403             }
404         }
405     }
406     return 0;
407 }
408 
collectShaderStorageBlocks(TIntermTyped * node)409 void ShaderStorageBlockOutputHLSL::collectShaderStorageBlocks(TIntermTyped *node)
410 {
411     TIntermSwizzle *swizzleNode = node->getAsSwizzleNode();
412     if (swizzleNode)
413     {
414         return collectShaderStorageBlocks(swizzleNode->getOperand());
415     }
416 
417     TIntermBinary *binaryNode = node->getAsBinaryNode();
418     if (binaryNode)
419     {
420         switch (binaryNode->getOp())
421         {
422             case EOpIndexDirectInterfaceBlock:
423             case EOpIndexIndirect:
424             case EOpIndexDirect:
425             case EOpIndexDirectStruct:
426                 return collectShaderStorageBlocks(binaryNode->getLeft());
427             default:
428                 UNREACHABLE();
429                 return;
430         }
431     }
432 
433     const TIntermSymbol *symbolNode = node->getAsSymbolNode();
434     const TType &type               = symbolNode->getType();
435     ASSERT(type.getQualifier() == EvqBuffer);
436     const TVariable &variable = symbolNode->variable();
437 
438     const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
439     ASSERT(interfaceBlock);
440     if (mReferencedShaderStorageBlocks.count(interfaceBlock->uniqueId().get()) == 0)
441     {
442         const TVariable *instanceVariable = nullptr;
443         if (type.isInterfaceBlock())
444         {
445             instanceVariable = &variable;
446         }
447         mReferencedShaderStorageBlocks[interfaceBlock->uniqueId().get()] =
448             new TReferencedBlock(interfaceBlock, instanceVariable);
449         GetShaderStorageBlockMembersInfo(interfaceBlock, mShaderStorageBlocks,
450                                          &mBlockMemberInfoMap);
451     }
452 }
453 
traverseSSBOAccess(TIntermTyped * node,SSBOMethod method)454 void ShaderStorageBlockOutputHLSL::traverseSSBOAccess(TIntermTyped *node, SSBOMethod method)
455 {
456     // TODO: Merge collectShaderStorageBlocks and GetBlockLayoutInfo to simplify the code.
457     collectShaderStorageBlocks(node);
458 
459     // Note that we don't have correct BlockMemberInfo from mBlockMemberInfoMap at the current
460     // point. But we must use those information to generate the right function name. So here we have
461     // to calculate them again.
462     TLayoutBlockStorage storage;
463     bool rowMajor;
464     GetBlockLayoutInfo(node, false, &storage, &rowMajor);
465     int unsizedArrayStride = 0;
466     if (node->getType().isUnsizedArray())
467     {
468         // The unsized array member must be the last member of a shader storage block.
469         TIntermBinary *binaryNode = node->getAsBinaryNode();
470         if (binaryNode)
471         {
472             const TInterfaceBlock *interfaceBlock =
473                 binaryNode->getLeft()->getType().getInterfaceBlock();
474             ASSERT(interfaceBlock);
475             const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
476             const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
477             auto fieldInfoIter                = mBlockMemberInfoMap.find(field);
478             ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
479             unsizedArrayStride = fieldInfoIter->second.arrayStride;
480         }
481         else
482         {
483             const TIntermSymbol *symbolNode       = node->getAsSymbolNode();
484             const TVariable &variable             = symbolNode->variable();
485             const TInterfaceBlock *interfaceBlock = symbolNode->getType().getInterfaceBlock();
486             ASSERT(interfaceBlock);
487             const TField *field =
488                 GetFieldMemberInShaderStorageBlock(interfaceBlock, variable.name());
489             auto fieldInfoIter = mBlockMemberInfoMap.find(field);
490             ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
491             unsizedArrayStride = fieldInfoIter->second.arrayStride;
492         }
493     }
494     bool isRowMajorMatrix = false;
495     int matrixStride      = getMatrixStride(node, storage, rowMajor, &isRowMajorMatrix);
496 
497     const TString &functionName = mSSBOFunctionHLSL->registerShaderStorageBlockFunction(
498         node->getType(), method, storage, isRowMajorMatrix, matrixStride, unsizedArrayStride,
499         node->getAsSwizzleNode());
500     TInfoSinkBase &out = mOutputHLSL->getInfoSink();
501     out << functionName;
502     out << "(";
503     BlockMemberInfo blockMemberInfo;
504     TIntermNode *loc = traverseNode(out, node, &blockMemberInfo);
505     out << ", ";
506     loc->traverse(mOutputHLSL);
507 }
508 
writeShaderStorageBlocksHeader(GLenum shaderType,TInfoSinkBase & out) const509 void ShaderStorageBlockOutputHLSL::writeShaderStorageBlocksHeader(GLenum shaderType,
510                                                                   TInfoSinkBase &out) const
511 {
512     if (mReferencedShaderStorageBlocks.empty())
513     {
514         return;
515     }
516 
517     mResourcesHLSL->allocateShaderStorageBlockRegisters(mReferencedShaderStorageBlocks);
518     out << "// Shader Storage Blocks\n\n";
519     if (shaderType == GL_COMPUTE_SHADER)
520     {
521         out << mResourcesHLSL->shaderStorageBlocksHeader(mReferencedShaderStorageBlocks);
522     }
523     else
524     {
525         out << kShaderStorageDeclarationString << "\n";
526     }
527     mSSBOFunctionHLSL->shaderStorageBlockFunctionHeader(out);
528 }
529 
// Emits to |out| the HLSL name of the buffer that |node|'s SSBO access chain reads from. Returns
// an expression tree (built with CreateUIntNode/Add/Mul) for the byte offset of the accessed
// element within that buffer.
//
// *blockMemberInfo is overwritten with the layout info of the most recently visited field;
// index operations higher up the chain rely on it to scale their index, so the left operand is
// always traversed before the current level's info is applied.
//
// Returns nullptr when |node| contributes no offset: the instance symbol of an instanced block,
// an element of an array of block instances, or any node kind not handled here.
TIntermTyped *ShaderStorageBlockOutputHLSL::traverseNode(TInfoSinkBase &out,
                                                         TIntermTyped *node,
                                                         BlockMemberInfo *blockMemberInfo)
{
    if (TIntermSymbol *symbolNode = node->getAsSymbolNode())
    {
        const TVariable &variable = symbolNode->variable();
        const TType &type         = variable.getType();
        if (type.isInterfaceBlock())
        {
            // Instanced block: the instance name identifies the buffer. The member offset is
            // produced by the EOpIndexDirectInterfaceBlock node above this symbol.
            out << DecorateVariableIfNeeded(variable);
        }
        else
        {
            // Non-instanced block: the symbol is the field itself, so the buffer is named after
            // the block and the field's offset is returned directly.
            const TInterfaceBlock *interfaceBlock = type.getInterfaceBlock();
            out << Decorate(interfaceBlock->name());
            const TField *field =
                GetFieldMemberInShaderStorageBlock(interfaceBlock, variable.name());
            return createFieldOffset(field, blockMemberInfo);
        }
    }
    else if (TIntermSwizzle *swizzleNode = node->getAsSwizzleNode())
    {
        // The swizzle itself is applied by the generated helper function; only the operand
        // determines the buffer and offset.
        return traverseNode(out, swizzleNode->getOperand(), blockMemberInfo);
    }
    else if (TIntermBinary *binaryNode = node->getAsBinaryNode())
    {
        switch (binaryNode->getOp())
        {
            case EOpIndexDirect:
            {
                const TType &leftType = binaryNode->getLeft()->getType();
                if (leftType.isInterfaceBlock())
                {
                    // Element of an array of block instances: each element is emitted as its
                    // own instance name and contributes no byte offset.
                    ASSERT(leftType.getQualifier() == EvqBuffer);
                    TIntermSymbol *instanceArraySymbol = binaryNode->getLeft()->getAsSymbolNode();

                    const int arrayIndex =
                        binaryNode->getRight()->getAsConstantUnion()->getIConst(0);
                    out << mResourcesHLSL->InterfaceBlockInstanceString(
                        instanceArraySymbol->getName(), arrayIndex);
                }
                else
                {
                    return writeEOpIndexDirectOrIndirectOutput(out, binaryNode, blockMemberInfo);
                }
                break;
            }
            case EOpIndexIndirect:
            {
                // We do not currently support indirect references to interface blocks
                ASSERT(binaryNode->getLeft()->getBasicType() != EbtInterfaceBlock);
                return writeEOpIndexDirectOrIndirectOutput(out, binaryNode, blockMemberInfo);
            }
            case EOpIndexDirectStruct:
            {
                // We do not currently support direct references to interface blocks
                ASSERT(binaryNode->getLeft()->getBasicType() != EbtInterfaceBlock);
                // Offset of the struct itself first, then the selected field's offset within
                // it; createFieldOffset also replaces *blockMemberInfo with the field's info.
                TIntermTyped *left = traverseNode(out, binaryNode->getLeft(), blockMemberInfo);
                const TStructure *structure       = binaryNode->getLeft()->getType().getStruct();
                const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
                const TField *field               = structure->fields()[index->getIConst(0)];
                return Add(createFieldOffset(field, blockMemberInfo), left);
            }
            case EOpIndexDirectInterfaceBlock:
            {
                ASSERT(IsInShaderStorageBlock(binaryNode->getLeft()));
                // The left operand (an instance or an element of an instance array) only emits
                // the buffer name; its returned offset is expected to be nullptr and is
                // discarded.
                traverseNode(out, binaryNode->getLeft(), blockMemberInfo);
                const TInterfaceBlock *interfaceBlock =
                    binaryNode->getLeft()->getType().getInterfaceBlock();
                const TIntermConstantUnion *index = binaryNode->getRight()->getAsConstantUnion();
                const TField *field               = interfaceBlock->fields()[index->getIConst(0)];
                return createFieldOffset(field, blockMemberInfo);
            }
            default:
                return nullptr;
        }
    }
    return nullptr;
}
610 
writeEOpIndexDirectOrIndirectOutput(TInfoSinkBase & out,TIntermBinary * node,BlockMemberInfo * blockMemberInfo)611 TIntermTyped *ShaderStorageBlockOutputHLSL::writeEOpIndexDirectOrIndirectOutput(
612     TInfoSinkBase &out,
613     TIntermBinary *node,
614     BlockMemberInfo *blockMemberInfo)
615 {
616     ASSERT(IsInShaderStorageBlock(node->getLeft()));
617     TIntermTyped *left  = traverseNode(out, node->getLeft(), blockMemberInfo);
618     TIntermTyped *right = node->getRight()->deepCopy();
619     const TType &type   = node->getLeft()->getType();
620     TLayoutBlockStorage storage;
621     bool rowMajor;
622     GetBlockLayoutInfo(node, false, &storage, &rowMajor);
623 
624     if (type.isArray())
625     {
626         const TSpan<const unsigned int> &arraySizes = type.getArraySizes();
627         for (unsigned int i = 0; i < arraySizes.size() - 1; i++)
628         {
629             right = Mul(CreateUIntNode(arraySizes[i]), right);
630         }
631         right = Mul(CreateUIntNode(blockMemberInfo->arrayStride), right);
632     }
633     else if (type.isMatrix())
634     {
635         if (rowMajor)
636         {
637             right = Mul(CreateUIntNode(BlockLayoutEncoder::kBytesPerComponent), right);
638         }
639         else
640         {
641             right = Mul(CreateUIntNode(blockMemberInfo->matrixStride), right);
642         }
643     }
644     else if (type.isVector())
645     {
646         if (blockMemberInfo->isRowMajorMatrix)
647         {
648             right = Mul(CreateUIntNode(blockMemberInfo->matrixStride), right);
649         }
650         else
651         {
652             right = Mul(CreateUIntNode(BlockLayoutEncoder::kBytesPerComponent), right);
653         }
654     }
655     return Add(left, right);
656 }
657 
createFieldOffset(const TField * field,BlockMemberInfo * blockMemberInfo)658 TIntermTyped *ShaderStorageBlockOutputHLSL::createFieldOffset(const TField *field,
659                                                               BlockMemberInfo *blockMemberInfo)
660 {
661     auto fieldInfoIter = mBlockMemberInfoMap.find(field);
662     ASSERT(fieldInfoIter != mBlockMemberInfoMap.end());
663     *blockMemberInfo = fieldInfoIter->second;
664     return CreateUIntNode(blockMemberInfo->offset);
665 }
666 
667 }  // namespace sh
668