xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/WorkloadData.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/backends/TensorHandle.hpp>
7 #include <armnn/backends/WorkloadData.hpp>
8 #include <armnn/backends/WorkloadInfo.hpp>
9 #include <armnnUtils/DataLayoutIndexed.hpp>
10 #include <armnnUtils/TensorUtils.hpp>
11 #include <armnnUtils/Permute.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13 #include <armnn/Logging.hpp>
14 
15 #include <algorithm>
16 #include <iomanip>
17 #include <string>
18 #include <sstream>
19 
20 #include <fmt/format.h>
21 
22 using namespace armnnUtils;
23 
24 namespace armnn
25 {
26 
27 //---------------------------------------------------------------
GetBiasDataType(DataType inputDataType)28 DataType GetBiasDataType(DataType inputDataType)
29 {
30     switch (inputDataType)
31     {
32         case DataType::Float16:
33             return DataType::Float16;
34         case DataType::BFloat16:
35         case DataType::Float32:
36             return DataType::Float32;
37         case DataType::QAsymmS8:
38         case DataType::QAsymmU8:
39         case DataType::QSymmS8:
40         case DataType::QSymmS16:
41             return DataType::Signed32;
42         default:
43             ARMNN_ASSERT_MSG(false, "Invalid input data type");
44             return DataType::Float32;
45     }
46 }
47 
48 namespace
49 {
50 
51 //---------------------------------------------------------------
52 //android ndk does not support std::to_string function.
template <typename T>
std::string to_string(T value)
{
    // Stream the value through an ostringstream and return the buffered text;
    // used because the Android NDK lacks std::to_string.
    std::ostringstream stream;
    stream << value;
    return stream.str();
}
60 
61 //---------------------------------------------------------------
ValidatePointer(const void * ptr,std::string const & descName,std::string const & paramName)62 void ValidatePointer(const void* ptr, std::string const& descName, std::string const& paramName)
63 {
64     if (!ptr)
65     {
66         throw InvalidArgumentException(descName +  ": Invalid null pointer. The " +
67                                       paramName + " parameter must be set.");
68     }
69 }
70 
71 //---------------------------------------------------------------
ValidateTensorShapesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)72 void ValidateTensorShapesMatch(const TensorInfo& first,
73                                const TensorInfo& second,
74                                std::string const& descName,
75                                std::string const& firstName,
76                                std::string const& secondName)
77 {
78     if (first.GetShape() != second.GetShape())
79     {
80         throw InvalidArgumentException(descName + ": "
81                                        + firstName + " & " + secondName + " must have identical shapes");
82     }
83 }
84 
85 //---------------------------------------------------------------
ValidateNumInputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)86 void ValidateNumInputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
87 {
88     if (workloadInfo.m_InputTensorInfos.size() != expectedSize)
89     {
90         throw InvalidArgumentException(descName +
91                                        ": Requires exactly " + to_string(expectedSize) + "input(s). " +
92                                        to_string(workloadInfo.m_InputTensorInfos.size()) + " have been provided.");
93     }
94 }
95 
96 //---------------------------------------------------------------
ValidateNumOutputs(const WorkloadInfo & workloadInfo,std::string const & descName,const unsigned int expectedSize)97 void ValidateNumOutputs(const WorkloadInfo& workloadInfo, std::string const& descName, const unsigned int expectedSize)
98 {
99     if (workloadInfo.m_OutputTensorInfos.size() != expectedSize)
100     {
101         throw InvalidArgumentException(descName +
102                                        ": Requires exactly " + to_string(expectedSize) + " output(s). " +
103                                        to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
104     }
105 }
106 
107 //---------------------------------------------------------------
108 
109 //---------------------------------------------------------------
ValidateTensorNumElements(const TensorInfo & tensor,std::string const & descName,unsigned int numElements,std::string const & tensorName)110 void ValidateTensorNumElements(const TensorInfo& tensor,
111                                std::string const& descName,
112                                unsigned int numElements,
113                                std::string const& tensorName)
114 {
115     if (tensor.GetNumElements() != numElements)
116     {
117         throw InvalidArgumentException(descName + ": Expected " + to_string(numElements) + " but got " +
118                                        to_string(tensor.GetNumElements()) + " elements for " +
119                                        tensorName + " tensor.");
120     }
121 }
122 
123 //---------------------------------------------------------------
ValidateTensorDataType(const TensorInfo & tensor,DataType dataType,const std::string & descName,std::string const & tensorName)124 void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
125     const std::string& descName, std::string const& tensorName)
126 {
127     if (tensor.GetDataType() != dataType)
128     {
129         throw InvalidArgumentException(descName + ": Expected data type " + GetDataTypeName(dataType) + " but got " +
130             GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
131     }
132 }
133 
ValidPerAxisQuantizedDataType(const TensorInfo & tensor,const std::string & descName,const std::string & tensorName)134 void ValidPerAxisQuantizedDataType(const TensorInfo& tensor, const std::string& descName, const std::string& tensorName)
135 {
136     if (tensor.GetDataType() != DataType::QSymmS8)
137     {
138         throw InvalidArgumentException(descName +
139             ": Expected data type which supports per-axis quantization scheme but got " +
140             GetDataTypeName(tensor.GetDataType()) + " for " + tensorName + " tensor.");
141     }
142 }
143 
144 //---------------------------------------------------------------
// Checks that two tensors share the same quantization "space": the same
// quantized data type and matching quantization parameters. Tensors that are
// not quantized are skipped. Throws InvalidArgumentException on the first
// mismatch found, naming both tensors in the message.
void ValidateTensorQuantizationSpace(const TensorInfo& first,
                                     const TensorInfo& second,
                                     const std::string& descName,
                                     std::string const& firstName,
                                     std::string const& secondName)
{
    if (!first.IsQuantized() ||
        !second.IsQuantized())
    {
        // Not a quantized type, ignore the validation
        return;
    }

    DataType firstDataType  = first.GetDataType();
    DataType secondDataType = second.GetDataType();

    // Both tensors must be quantized with the same data type.
    if (firstDataType != secondDataType)
    {
        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
                                       " must be of the same quantized type, " +
                                       firstName + " is " + GetDataTypeName(firstDataType) + ", " +
                                       secondName + " is " + GetDataTypeName(secondDataType));
    }

    // IsTypeSpaceMatch compares the quantization parameters of both tensors;
    // the message reports the per-tensor offset and scale on failure.
    if (!first.IsTypeSpaceMatch(second))
    {
        throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
                                       " must have the same quantization space, " +
                                       firstName + " has offset " + to_string(first.GetQuantizationOffset()) +
                                       " and scale " + to_string(first.GetQuantizationScale()) + ", " +
                                       secondName + " has offset " + to_string(second.GetQuantizationOffset()) +
                                       " and scale " + to_string(second.GetQuantizationScale()));
    }
}
179 
180 //---------------------------------------------------------------
// Validates the quantization parameters of a bias tensor against its input and
// weight tensors. The bias offset must be zero (throws otherwise). Each bias
// scale is expected to equal input_scale * weight_scale; a mismatch is only
// logged as a warning and the provided scale is kept. Handles both per-tensor
// and per-axis (multiple-scale) quantization; per-axis weights and biases must
// carry the same number of scales.
void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
                                    const TensorInfo& inputTensorInfo,
                                    const TensorInfo& weightsTensorInfo,
                                    const std::string& descName)
{
    // Helper lambda function to validate a single bias quantization scale value
    auto VerifyBiasQuantizationScale = [&descName](float biasScale, float expectedScale) -> void
    {
        // Tolerance absorbs rounding differences between the framework that
        // produced the model and this float multiplication.
        constexpr float tolerance = 0.0001f;
        if (std::abs(biasScale - expectedScale) > tolerance)
        {
            // Print the float values with extra precision to see very small differences
            ARMNN_LOG(warning) << std::setprecision(6) << descName << ": Expected " << expectedScale <<
                " for bias quantization scale (product of input and weight scales), but got " <<
                biasScale << ". Using scale provided.";
        }
    };

    // A non-zero bias offset is a hard error.
    if (biasTensor.GetQuantizationOffset() != 0)
    {
        throw InvalidArgumentException(descName + ": Expected zero quantization offset for bias tensor but got " +
            to_string(biasTensor.GetQuantizationOffset()));
    }

    if (biasTensor.HasMultipleQuantizationScales() || weightsTensorInfo.HasMultipleQuantizationScales())
    {
        // Validate per-axis quantization scales
        const std::vector<float>& weightScales = weightsTensorInfo.GetQuantizationScales();
        const std::vector<float>& biasScales   = biasTensor.GetQuantizationScales();

        if (weightScales.size() != biasScales.size())
        {
            std::stringstream msg;
            msg << descName << ": Expected matching number of per-axis quantization scales for weights and bias, "
                << "but got different values. This is currently unsupported: weights=" << weightScales.size()
                << ", biases=" << biasScales.size();
            throw InvalidArgumentException(msg.str(), CHECK_LOCATION());
        }

        // Each axis' bias scale should be the product of the (per-tensor) input
        // scale and that axis' weight scale.
        for (size_t i = 0ul; i < biasScales.size(); ++i)
        {
            const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightScales[i];
            VerifyBiasQuantizationScale(biasScales[i], expectedScale);
        }
    }
    else
    {
        // Validate per-tensor quantization scale
        const float expectedScale = inputTensorInfo.GetQuantizationScale() * weightsTensorInfo.GetQuantizationScale();
        VerifyBiasQuantizationScale(biasTensor.GetQuantizationScale(), expectedScale);
    }
}
233 
234 //---------------------------------------------------------------
ValidateTensors(const std::vector<ITensorHandle * > & vec,unsigned int numExpected,const std::string & descName,const std::string & varName)235 void ValidateTensors(const std::vector<ITensorHandle*>& vec,
236     unsigned int numExpected,
237     const std::string& descName,
238     const std::string& varName)
239 {
240     if (vec.empty() && numExpected > 0)
241     {
242         throw InvalidArgumentException(descName + ": Invalid empty " + varName + " array.");
243     }
244 
245     for (unsigned int i = 0; i < numExpected; ++i)
246     {
247         if (!vec[i])
248         {
249             throw InvalidArgumentException(descName + ": Invalid NULL for " + varName + to_string(i));
250         }
251     }
252 }
253 
254 //---------------------------------------------------------------
// Validates that @p first and @p second can be broadcast together and that the
// broadcast result shape equals @p output's shape. Both inputs must already
// have the same rank; per dimension, the sizes must either match or one of
// them must be 1 (numpy-style broadcasting). Throws InvalidArgumentException
// on any violation.
void ValidateBroadcastTensorShapesMatch(const TensorInfo& first,
                                        const TensorInfo& second,
                                        const TensorInfo& output,
                                        std::string const& descName,
                                        std::string const& firstName,
                                        std::string const& secondName)
{
    // Tensors must have the same number of dimensions in order to be explicit about which dimensions will get
    // broadcasted.
    if (first.GetNumDimensions() != second.GetNumDimensions())
    {
        throw InvalidArgumentException(descName  + ": Tensors "
            + firstName + " & " + secondName
            + " must have the same number of dimensions in order to be broadcasted");
    }
    uint32_t numDims = first.GetNumDimensions();
    std::vector<uint32_t> outputDims(numDims, 0u);
    for (uint32_t i = 0; i < numDims; i++)
    {
        // A dimension pair is broadcastable when the sizes are equal or at
        // least one of them is 1.
        const bool dimsNotEqual = first.GetShape()[i] != second.GetShape()[i];
        const bool dimsNotOne = (first.GetShape()[i] != 1) && (second.GetShape()[i] != 1);
        if (dimsNotEqual && dimsNotOne)
        {
            throw InvalidArgumentException("Broadcasting is not possible for incompatible shapes");
        }
        // The broadcast result takes the larger of the two sizes.
        outputDims[i] = std::max(first.GetShape()[i], second.GetShape()[i]);
    }
    // The computed broadcast shape must equal the declared output shape.
    TensorShape broadcastShape = TensorShape(armnn::numeric_cast<unsigned int>(outputDims.size()), outputDims.data());
    if (broadcastShape != output.GetShape())
    {
        throw InvalidArgumentException(descName + ": The tensor shape resulting from adding "
                                       + firstName + " & " + secondName
                                       + " does not match the output shape");
    }
}
290 
291 //---------------------------------------------------------------
ValidateDataTypes(const TensorInfo & info,const std::vector<armnn::DataType> & supportedTypes,std::string const & descName)292 void ValidateDataTypes(const TensorInfo& info,
293                        const std::vector<armnn::DataType>& supportedTypes,
294                        std::string const& descName)
295 {
296     auto iterator = std::find(supportedTypes.begin(), supportedTypes.end(), info.GetDataType());
297     if (iterator == supportedTypes.end())
298     {
299         throw InvalidArgumentException(descName  + ": " + " Tensor type is not supported.");
300     }
301 }
302 
303 //---------------------------------------------------------------
ValidateTensorDataTypesMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)304 void ValidateTensorDataTypesMatch(const TensorInfo& first,
305                                   const TensorInfo& second,
306                                   std::string const& descName,
307                                   std::string const& firstName,
308                                   std::string const& secondName)
309 {
310     if (first.GetDataType() != second.GetDataType())
311     {
312         throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
313                                        " must have identical data types.");
314     }
315 }
316 
317 //---------------------------------------------------------------
ValidateTensorNumElementsMatch(const TensorInfo & first,const TensorInfo & second,std::string const & descName,std::string const & firstName,std::string const & secondName)318 void ValidateTensorNumElementsMatch(const TensorInfo& first,
319                                     const TensorInfo& second,
320                                     std::string const& descName,
321                                     std::string const& firstName,
322                                     std::string const& secondName)
323 {
324     if (first.GetNumElements() != second.GetNumElements())
325     {
326         throw InvalidArgumentException(descName + ": " + firstName + " & " + secondName +
327                                        " must have the same number of elements.");
328     }
329 }
330 
ValidateWeightDataType(const TensorInfo & inputInfo,const TensorInfo & weightInfo,const std::string & descName)331 void ValidateWeightDataType(const TensorInfo& inputInfo,
332                             const TensorInfo& weightInfo,
333                             const std::string& descName)
334 {
335     const DataType inputType = inputInfo.GetDataType();
336     if (IsQuantized8BitType(inputType))
337     {
338         const std::vector<DataType> validTypes =
339         {
340             DataType::QAsymmS8,
341             DataType::QAsymmU8,
342             DataType::QSymmS8
343         };
344 
345         ValidateDataTypes(weightInfo, validTypes, descName);
346     }
347     else
348     {
349         ValidateTensorDataTypesMatch(inputInfo, weightInfo, descName, "input", "weight");
350     }
351 }
352 
ValidatePerAxisQuantizationDimension(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)353 void ValidatePerAxisQuantizationDimension(const TensorInfo& tensorInfo,
354                                           const std::string& descName,
355                                           const std::string& tensorName)
356 {
357     const Optional<unsigned int>& quantizationDim = tensorInfo.GetQuantizationDim();
358     if (!quantizationDim.has_value())
359     {
360         throw InvalidArgumentException(fmt::format("{0}: Quantization dimension for per-axis quantization "
361                                                    "not set on tensor {1}.", descName, tensorName));
362     }
363 }
364 
ValidatePerAxisQuantizationOffset(const TensorInfo & tensorInfo,const std::string & descName,const std::string & tensorName)365 void ValidatePerAxisQuantizationOffset(const TensorInfo& tensorInfo,
366                                        const std::string& descName,
367                                        const std::string& tensorName)
368 {
369     int32_t quantizationOffset = tensorInfo.GetQuantizationOffset();
370     if (quantizationOffset != 0)
371     {
372         throw InvalidArgumentException(fmt::format(
373             "{0}: Quantization offset for per-axis quantization expected to be 0 on tensor {1}, but got: {2}",
374             descName, tensorName, quantizationOffset));
375     }
376 }
377 
// If the weight tensor is per-axis quantized, validates the full per-axis
// quantization contract: input/output must be matching quantized 8-bit types,
// the weights must be QSymmS8 with a quantization dimension and zero offset,
// and any bias must itself be per-axis quantized Signed32 with a dimension and
// zero offset. A no-op for weights without per-axis quantization.
void ValidatePerAxisQuantization(const TensorInfo& inputInfo,
                                 const TensorInfo& outputInfo,
                                 const TensorInfo& weightInfo,
                                 const Optional<TensorInfo>& optionalBiasInfo,
                                 const std::string& descName)
{
    if (weightInfo.HasPerAxisQuantization())
    {
        const DataType inputDataType  = inputInfo.GetDataType();
        const DataType outputDataType = outputInfo.GetDataType();

        // Per-axis quantization is only allowed when input and output are the
        // same quantized 8-bit type.
        const bool canHavePerAxisQuantization = (IsQuantized8BitType(inputDataType)) && inputDataType == outputDataType;

        if (!canHavePerAxisQuantization)
        {
            throw InvalidArgumentException(fmt::format(
                "{0}: Per-axis quantization parameters set on tensor {1}, but data type does not support "
                "per-axis quantization.", descName, "weight"));
        }


        // Weights: must be QSymmS8, declare a quantization dimension, and use
        // a zero quantization offset.
        ValidPerAxisQuantizedDataType(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationDimension(weightInfo, descName, "weight");
        ValidatePerAxisQuantizationOffset(weightInfo, descName, "weight");

        if (optionalBiasInfo.has_value())
        {
            // Bias (when present) must mirror the weights' per-axis scheme.
            const TensorInfo& biasInfo = optionalBiasInfo.value();
            if (!biasInfo.HasPerAxisQuantization())
            {
                throw InvalidArgumentException(fmt::format(
                        "{}: Per-axis quantization parameters not set on bias tensor, "
                        "despite being set on weight tensor.", descName));
            }

            ValidateTensorDataType(biasInfo, DataType::Signed32, descName, "bias");
            ValidatePerAxisQuantizationDimension(biasInfo, descName, "bias");
            ValidatePerAxisQuantizationOffset(biasInfo, descName, "bias");
        }
    }
}
419 
420 } // anonymous namespace
421 
422 //---------------------------------------------------------------
// Validates the tensor's rank against @p numDimensions. In strict mode the
// rank must match exactly. When m_AllowExpandedDims is set, numDimensions is
// treated as an upper bound on the "squeezed" rank (dimensions of size 1 are
// ignored) and a lower bound on the actual rank.
void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
                                                  std::string const& descName,
                                                  unsigned int numDimensions,
                                                  std::string const& tensorName) const
{
    // If we're allowing expanded dimensions then numDimensions becomes the minimum number of Dimensions we can allow.
    // Throw an Exception if the tensors has fewer than numDimensions or if the squeezed dimensions are greater than
    // numDimensions.
    if (m_AllowExpandedDims)
    {
        // Count only the non-unit dimensions ("squeezed" rank).
        unsigned int squeezedDims = 0;

        for (unsigned int i = 0; i < tensor.GetNumDimensions(); ++i)
        {
            if (tensor.GetShape()[i] != 1)
            {
                ++squeezedDims;
            }
        }
        if (tensor.GetNumDimensions() < numDimensions || squeezedDims > numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " or less but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
    else
    {
        // Strict mode: the rank must equal numDimensions exactly.
        if (tensor.GetNumDimensions() != numDimensions)
        {
            throw InvalidArgumentException(descName + ": Expected " + to_string(numDimensions) + " but got " +
                                           to_string(tensor.GetNumDimensions()) + " dimensions for " +
                                           tensorName + " tensor.");
        }
    }
}
459 
460 //---------------------------------------------------------------
ValidateTensorNumDimNumElem(const TensorInfo & tensorInfo,unsigned int numDimension,unsigned int numElements,std::string const & tensorName) const461 void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
462                                  unsigned int numDimension,
463                                  unsigned int numElements,
464                                  std::string const& tensorName) const
465 {
466     const std::string functionName{"ValidateTensorNumDimNumElem"};
467     ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
468     ValidateTensorNumElements(tensorInfo, functionName, numElements, tensorName);
469 }
470 
471 //---------------------------------------------------------------
// Validates that m_Inputs holds at least numExpectedIn non-null handles and
// m_Outputs at least numExpectedOut, via ValidateTensors.
void QueueDescriptor::ValidateInputsOutputs(const std::string& descName,
    unsigned int numExpectedIn, unsigned int numExpectedOut) const
{
    ValidateTensors(m_Inputs, numExpectedIn, descName, "input");
    ValidateTensors(m_Outputs, numExpectedOut, descName, "output");
}
478 
479 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const480 void MapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
481 {
482     const std::string descriptorName{"MapQueueDescriptor"};
483 
484     ValidateNumInputs(workloadInfo,  descriptorName, 1);
485     ValidateNumOutputs(workloadInfo, descriptorName, 0);
486 
487     for (unsigned int i = 0; i < m_Inputs.size(); ++i)
488     {
489         if (!m_Inputs[i])
490         {
491             throw InvalidArgumentException(
492                 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
493         }
494     }
495 }
496 
497 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const498 void UnmapQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
499 {
500     const std::string descriptorName{"UnmapQueueDescriptor"};
501 
502     ValidateNumInputs(workloadInfo,  descriptorName, 1);
503     ValidateNumOutputs(workloadInfo, descriptorName, 0);
504 
505     for (unsigned int i = 0; i < m_Inputs.size(); ++i)
506     {
507         if (!m_Inputs[i])
508         {
509             throw InvalidArgumentException(
510                 fmt::format("{}: Invalid NULL input {}.", descriptorName, static_cast<int>(i)));
511         }
512     }
513 }
514 
515 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const516 void MemCopyQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
517 {
518     const std::string descriptorName{"MemCopyQueueDescriptor"};
519 
520     ValidateNumInputs(workloadInfo,  descriptorName, 1);
521     ValidateNumOutputs(workloadInfo, descriptorName , 1);
522 
523     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
524     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
525 
526     ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
527     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
528 
529     if (m_Inputs.size() != m_Outputs.size())
530     {
531         throw InvalidArgumentException(fmt::format(
532             "{0}: Number of inputs ({1}) does not match the number of outputs ({2}).",
533             descriptorName, m_Inputs.size(), m_Outputs.size()));
534     }
535 
536     for (unsigned int i = 0; i < m_Inputs.size(); ++i)
537     {
538         if (!m_Inputs[i])
539         {
540             throw InvalidArgumentException(fmt::format(
541                 "{0}: Invalid NULL input {1}.", descriptorName, i));
542         }
543 
544         if (!m_Outputs[i])
545         {
546             throw InvalidArgumentException(fmt::format("{0}: Invalid NULL output {1}", descriptorName, i));
547         }
548     }
549 }
550 
551 //---------------------------------------------------------------
// MemImport requires exactly one input and one output. Validates that the
// tensor-info lists and the handle lists are equal-length (and of size 1 for
// the inputs), that each input/output info pair has the same element count,
// and that no handle is null. Throws InvalidArgumentException otherwise.
void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    ValidateNumInputs(workloadInfo, "MemImportQueueDescriptor", 1);
    ValidateNumOutputs(workloadInfo, "MemImportQueueDescriptor" , 1);

    if (workloadInfo.m_InputTensorInfos.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of input infos ({}) is not 1.",
                                                   workloadInfo.m_InputTensorInfos.size()));

    }

    if (workloadInfo.m_InputTensorInfos.size() != workloadInfo.m_OutputTensorInfos.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of input infos ({0}) does not match the number of output infos ({1})",
            workloadInfo.m_InputTensorInfos.size(), workloadInfo.m_OutputTensorInfos.size()));
    }

    // Pairwise check: each imported tensor must keep its element count.
    for (std::size_t i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetNumElements() !=
            workloadInfo.m_OutputTensorInfos[i].GetNumElements())
        {
            throw InvalidArgumentException(fmt::format(
                "Number of elements for tensor input and output {} does not match", i ));
        }
    }

    if (m_Inputs.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
    }

    if (m_Inputs.size() != m_Outputs.size())
    {
        throw InvalidArgumentException(fmt::format(
            "Number of inputs ({0}) does not match the number of outputs ({1})",
            m_Inputs.size(), m_Outputs.size()));
    }

    // All tensor handles must be non-null.
    for (unsigned int i = 0; i < m_Inputs.size(); ++i)
    {
        if (!m_Inputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null input {}", i));
        }

        if (!m_Outputs[i])
        {
            throw InvalidArgumentException(fmt::format("Invalid null output {}", i));
        }
    }
}
606 
607 //---------------------------------------------------------------
// MemSync takes exactly one non-null input handle and produces no outputs.
void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);

    if (m_Inputs.size() != 1)
    {
        throw InvalidArgumentException(fmt::format("Number of inputs ({}) is not 1.", m_Inputs.size()));
    }

    if (!m_Outputs.empty())
    {
        throw InvalidArgumentException(fmt::format("Number of outputs ({}) is not 0.", m_Outputs.size()));
    }

    if (m_Inputs[0] == nullptr)
    {
        throw InvalidArgumentException(fmt::format("Invalid null input 0"));
    }
}
627 
628 //---------------------------------------------------------------
Validate(const WorkloadInfo & workloadInfo) const629 void ActivationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
630 {
631     const std::string descriptorName{"ActivationQueueDescriptor"};
632 
633     ValidateNumInputs(workloadInfo,  descriptorName, 1);
634     ValidateNumOutputs(workloadInfo, descriptorName, 1);
635 
636     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
637     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
638 
639     std::vector<DataType> supportedTypes =
640     {
641         DataType::BFloat16,
642         DataType::Float16,
643         DataType::Float32,
644         DataType::QAsymmS8,
645         DataType::QAsymmU8,
646         DataType::QSymmS16
647     };
648 
649     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
650     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
651     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
652 }
653 
Validate(const WorkloadInfo & workloadInfo) const654 void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
655 {
656     const std::string descriptorName{"ArgMinMaxQueueDescriptor"};
657 
658     ValidateNumInputs(workloadInfo,  descriptorName, 1);
659     ValidateNumOutputs(workloadInfo, descriptorName, 1);
660 
661     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
662     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
663 
664     if (outputTensorInfo.GetDataType() != DataType::Signed32 &&
665         outputTensorInfo.GetDataType() != DataType::Signed64)
666     {
667         throw InvalidArgumentException(descriptorName + ": Output of ArgMinMax layer must be Int32 or Int64.");
668     }
669 
670     std::vector<DataType> supportedInputTypes =
671     {
672         DataType::BFloat16,
673         DataType::Float16,
674         DataType::Float32,
675         DataType::QAsymmS8,
676         DataType::QAsymmU8,
677         DataType::QSymmS16,
678         DataType::Signed32,
679         DataType::Signed64
680     };
681 
682     ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
683 
684     auto inputShape = inputTensorInfo.GetShape();
685     auto outputShape = outputTensorInfo.GetShape();
686 
687     auto inputNumDimensions = inputShape.GetNumDimensions();
688     auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
689 
690     const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
691 
692     // 1D input shape results in scalar output shape
693     if (inputShape.GetNumDimensions() == 1)
694     {
695         if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
696         {
697             throw InvalidArgumentException(descriptorName + outputShapeError);
698         }
699     }
700     else
701     {
702         for (unsigned int i = 0; i < unsignedAxis; ++i)
703         {
704             if (outputShape[i] != inputShape[i])
705             {
706                 throw InvalidArgumentException(descriptorName + outputShapeError);
707             }
708         }
709 
710         for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
711         {
712             if (outputShape[i - 1] != inputShape[i])
713             {
714                 throw InvalidArgumentException(descriptorName + outputShapeError);
715             }
716         }
717     }
718 }
719 
Validate(const WorkloadInfo & workloadInfo) const720 void CastQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
721 {
722     const std::string descriptorName{"CastQueueDescriptor"};
723 
724     ValidateNumInputs(workloadInfo,  descriptorName, 1);
725     ValidateNumOutputs(workloadInfo, descriptorName, 1);
726 
727     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
728     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
729 
730     std::vector<DataType> supportedTypes =
731     {
732             DataType::BFloat16,
733             DataType::Float16,
734             DataType::Float32,
735             DataType::QAsymmS8,
736             DataType::QAsymmU8,
737             DataType::QSymmS8,
738             DataType::QSymmS16,
739             DataType::Signed32,
740             DataType::Signed64
741     };
742 
743     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
744     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
745 }
746 
Validate(const WorkloadInfo & workloadInfo) const747 void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
748 {
749     const std::string descriptorName{"SoftmaxQueueDescriptor"};
750 
751     ValidateNumInputs(workloadInfo,  descriptorName, 1);
752     ValidateNumOutputs(workloadInfo, descriptorName, 1);
753 
754     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
755     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
756 
757     std::vector<DataType> supportedTypes =
758     {
759         DataType::BFloat16,
760         DataType::Float16,
761         DataType::Float32,
762         DataType::QAsymmS8,
763         DataType::QAsymmU8,
764         DataType::QSymmS16
765     };
766 
767     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
768     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
769     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
770 }
771 
// Validates a Splitter workload: one input is carved into several output views,
// each described by a ViewOrigin giving its starting coordinate in the input.
void SplitterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"SplitterQueueDescriptor"};

    ValidateNumInputs(workloadInfo, descriptorName, 1);

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    // Every output view must be a supported type and match the input's type.
    const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
    for (unsigned long i = 0ul; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
    {
        const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[i];
        ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);

        const std::string outputName = "output_" + std::to_string(i);
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", outputName);
    }

    if (workloadInfo.m_OutputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
    }

    // Exactly one split window (view origin) is required per output tensor.
    if (workloadInfo.m_OutputTensorInfos.size() != m_ViewOrigins.size())
    {
        throw InvalidArgumentException(
            descriptorName + ": Number of split windows "
            "has to match number of workloadInfo.m_OutputTensorInfos. "
            "Number of windows: " +
            to_string(m_ViewOrigins.size()) +
            ". Number of workloadInfo.m_OutputTensorInfos: " + to_string(workloadInfo.m_OutputTensorInfos.size()));
    }

    //The dimensionality of all the windows has to match the dimensionality (not shape) of the input.
    std::size_t inputDims = workloadInfo.m_InputTensorInfos[0].GetNumDimensions();
    for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
    {
        //Checks that the dimensionality of input is same as the split windows.
        ViewOrigin const& e = m_ViewOrigins[w];
        if (e.m_Origin.size() != inputDims)
        {
            throw InvalidArgumentException(descriptorName + ": Window origin have to "
                                           "have the same dimensionality as the input tensor. "
                                           "Window origin (index: " +
                                           to_string(w) + ") has " + to_string(e.m_Origin.size()) +
                                           " dimensions, the input "
                                           "tensor has " +
                                           to_string(inputDims) + " dimensions.");
        }
        // Each window (origin + the corresponding output's extent) must lie
        // entirely within the input tensor along every dimension.
        for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
        {
            if (e.m_Origin[i] + workloadInfo.m_OutputTensorInfos[w].GetShape()[i] >
                workloadInfo.m_InputTensorInfos[0].GetShape()[i])
            {
                throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
                                               "be smaller or equal than the size of the input in that coord.");
            }
        }
    }
}
843 
// Validates a Concat workload: several inputs are merged into one output,
// each input's placement described by a ViewOrigin in the output tensor.
void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"ConcatQueueDescriptor"};

    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    if (m_Inputs.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one input needs to be provided.");
    }
    if (m_Outputs.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one output needs to be provided.");
    }

    if (workloadInfo.m_InputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one TensorInfo input needs to be provided.");
    }
    if (workloadInfo.m_OutputTensorInfos.size() <= 0)
    {
        throw InvalidArgumentException(descriptorName + ": At least one TensorInfo output needs to be provided.");
    }

    // The concat axis must fall within the input's rank.
    if(m_Parameters.GetConcatAxis() > workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions())
    {
        throw InvalidArgumentException(descriptorName + ": Invalid concatenation axis provided.");
    }

    // NOTE(review): when concatenating along the innermost dimension all of the
    // window/type checks below are skipped — presumably that case does not use
    // per-view windows; confirm against the concat workload implementations.
    if (workloadInfo.m_InputTensorInfos[0].GetShape().GetNumDimensions() - m_Parameters.GetConcatAxis() == 1)
    {
        return;
    }

    // Exactly one merge window (view origin) is required per input tensor.
    if (workloadInfo.m_InputTensorInfos.size() != m_ViewOrigins.size())
    {
        throw InvalidArgumentException(
            descriptorName + ": Number of split windows "
            "has to match number of workloadInfo.m_InputTensorInfos. "
            "Number of windows: " +
            to_string(m_ViewOrigins.size()) +
            ". Number of workloadInfo.m_InputTensorInfos: " + to_string(workloadInfo.m_InputTensorInfos.size()));
    }

    //The dimensionality of all the windows has to match the dimensionality (not shape) of the output.
    std::size_t outputDims = workloadInfo.m_OutputTensorInfos[0].GetNumDimensions();
    for(unsigned int w = 0; w < m_ViewOrigins.size(); ++w )
    {
        //Checks that the dimensionality of output is same as the split windows.
        ViewOrigin const& e = m_ViewOrigins[w];
        if (e.m_Origin.size() != outputDims)
        {
            throw InvalidArgumentException(descriptorName + ": Window origin have to "
                                           "have the same dimensionality as the output tensor. "
                                           "Window origin (index: " +
                                           to_string(w) + ") has " + to_string(e.m_Origin.size()) +
                                           " dimensions, the output "
                                           "tensor has " +
                                           to_string(outputDims) + " dimensions.");
        }
        //Checks that the merge windows are within the output tensor.
        for (unsigned int i = 0; i < e.m_Origin.size(); ++i)
        {
            if (e.m_Origin[i] + workloadInfo.m_InputTensorInfos[w].GetShape()[i]
                > workloadInfo.m_OutputTensorInfos[0].GetShape()[i])
            {
                throw InvalidArgumentException(descriptorName + ": Window extent coordinates have to "
                                               "be smaller or equal than the size of the output in that coord.");
            }
        }
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    // Every input must be a supported type and agree with the output's type.
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
    for (unsigned long i = 0ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[i];
        ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

        const std::string inputName = "input_" + std::to_string(i);
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, inputName, "output");
    }
}
939 
// Validates a Stack workload: N identically-shaped inputs are stacked along a
// new axis, so the output rank is the input rank plus one.
void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"StackQueueDescriptor"};

    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    // The number of inputs is fixed by the descriptor parameters.
    if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
    {
        throw InvalidArgumentException(descriptorName + ": Must have the defined number of input tensors.");
    }

    // All inputs must have the same shape, which is defined in parameters
    const TensorShape& inputShape = m_Parameters.m_InputShape;
    for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
        {
            throw InvalidArgumentException(descriptorName + ": All input tensor shapes must match the defined shape.");
        }
    }

    if (inputShape.GetNumDimensions() > 4)
    {
        throw InvalidArgumentException(descriptorName + ": Input tensor may have up to 4 dimensions.");
    }

    // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
    // since the output tensor has an additional dimension.
    if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
    {
        throw InvalidArgumentException(descriptorName + ": Axis may not be greater "
                                       "than the number of input dimensions.");
    }

    // Output shape must be as inferred from the input shape:
    // dimensions before the stack axis are carried over unchanged...
    const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
    for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
    {
        if (outputShape[i] != inputShape[i])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    // ...the stack axis itself holds one slice per input...
    if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                       "match shape inferred from input tensor.");
    }

    // ...and dimensions after the axis are the input dimensions shifted up by one.
    for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
    {
        if (outputShape[i] != inputShape[i-1])
        {
            throw InvalidArgumentException(descriptorName + ": Output tensor must "
                                           "match shape inferred from input tensor.");
        }
    }

    if (outputShape.GetNumDimensions() > 5)
    {
        throw InvalidArgumentException(descriptorName + ": Output tensor may have up to 5 dimensions.");
    }

    // Check the supported data types
    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float32,
        DataType::Float16,
        DataType::Boolean,
        DataType::Signed32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16
    };

    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);

    // All remaining inputs, and the output, must share input_0's data type.
    for (unsigned int i = 1ul; i < workloadInfo.m_InputTensorInfos.size(); ++i)
    {
        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                     workloadInfo.m_InputTensorInfos[i],
                                     descriptorName,
                                     "input_0",
                                     "input_" + std::to_string(i));
    }

    ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
                                 workloadInfo.m_OutputTensorInfos[0],
                                 descriptorName,
                                 "input_0",
                                 "output");
}
1035 
Validate(const WorkloadInfo & workloadInfo) const1036 void FillQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1037 {
1038     const std::string descriptorName{"FillQueueDescriptor"};
1039 
1040     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1041     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1042 
1043     const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
1044     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1045 
1046     ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 1, "input");
1047 
1048     std::vector<DataType> supportedTypes =
1049     {
1050         DataType::BFloat16,
1051         DataType::Float32,
1052         DataType::Float16,
1053         DataType::Signed32
1054     };
1055 
1056     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1057 }
1058 
Validate(const WorkloadInfo & workloadInfo) const1059 void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1060 {
1061     const std::string descriptorName{"FullyConnectedQueueDescriptor"};
1062 
1063     uint32_t numInputs = 2;
1064     if (m_Parameters.m_BiasEnabled)
1065     {
1066         numInputs = 3;
1067     }
1068 
1069     ValidateNumInputs(workloadInfo, descriptorName, numInputs);
1070     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1071 
1072     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1073     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1074 
1075     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1076 
1077     if (!(inputTensorInfo.GetNumDimensions() == 2 || inputTensorInfo.GetNumDimensions() == 4))
1078     {
1079         throw InvalidArgumentException(descriptorName + ": Input tensor must have 2 or 4 dimensions.");
1080     }
1081 
1082     TensorInfo weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1083     ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 2, "weight");
1084 
1085     if (m_Parameters.m_BiasEnabled)
1086     {
1087         TensorInfo biasTensorInfo = workloadInfo.m_InputTensorInfos[2];
1088         // Validates type and quantization values.
1089         ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1090         ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1091         ValidateTensorNumDimensions(biasTensorInfo, descriptorName, 1, "bias");
1092     }
1093 
1094     // Check the supported data types
1095     std::vector<DataType> supportedTypes =
1096     {
1097         DataType::BFloat16,
1098         DataType::Float32,
1099         DataType::Float16,
1100         DataType::QAsymmS8,
1101         DataType::QAsymmU8,
1102         DataType::QSymmS16
1103     };
1104 
1105     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1106 
1107     // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
1108     if (inputTensorInfo.GetDataType() == DataType::BFloat16)
1109     {
1110         if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
1111         {
1112             throw InvalidArgumentException(descriptorName  + ": " + " Output tensor type must be BFloat16 or Float32 "
1113                                            "for BFloat16 input.");
1114         }
1115     }
1116     else
1117     {
1118         ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1119     }
1120 }
1121 
Validate(const WorkloadInfo & workloadInfo) const1122 void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1123 {
1124     const std::string descriptorName{"NormalizationQueueDescriptor"};
1125 
1126     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1127     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1128 
1129     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1130     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1131 
1132     // Check the supported data types
1133     std::vector<DataType> supportedTypes =
1134     {
1135         DataType::BFloat16,
1136         DataType::Float16,
1137         DataType::Float32,
1138         DataType::QAsymmS8,
1139         DataType::QAsymmU8,
1140         DataType::QSymmS16
1141     };
1142 
1143     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1144 
1145     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1146 
1147     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1148 }
1149 
Validate(const WorkloadInfo & workloadInfo) const1150 void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1151 {
1152     const std::string descriptorName{"AdditionQueueDescriptor"};
1153 
1154     ValidateNumInputs(workloadInfo,  descriptorName, 2);
1155     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1156 
1157     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1158     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1159     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1160 
1161     std::vector<DataType> supportedTypes =
1162     {
1163         DataType::BFloat16,
1164         DataType::Float32,
1165         DataType::Float16,
1166         DataType::QAsymmS8,
1167         DataType::QAsymmU8,
1168         DataType::QSymmS16,
1169         DataType::Signed32
1170     };
1171 
1172     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1173     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1174     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1175 
1176     ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1177     ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1178 
1179     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1180                                        inputTensorInfo1,
1181                                        outputTensorInfo,
1182                                        descriptorName,
1183                                        "input_0",
1184                                        "input_1");
1185 }
1186 
Validate(const WorkloadInfo & workloadInfo) const1187 void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1188 {
1189     const std::string descriptorName{"MultiplicationQueueDescriptor"};
1190 
1191     ValidateNumInputs(workloadInfo,  descriptorName, 2);
1192     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1193 
1194     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
1195     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
1196     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1197 
1198     std::vector<DataType> supportedTypes =
1199     {
1200         DataType::BFloat16,
1201         DataType::Float16,
1202         DataType::Float32,
1203         DataType::QAsymmS8,
1204         DataType::QAsymmU8,
1205         DataType::QSymmS16,
1206         DataType::Signed32
1207     };
1208 
1209     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
1210     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
1211     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1212 
1213     ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
1214     ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input_1", "output");
1215 
1216     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
1217                                        inputTensorInfo1,
1218                                        outputTensorInfo,
1219                                        descriptorName,
1220                                        "input_0",
1221                                        "input_1");
1222 }
1223 
Validate(const WorkloadInfo & workloadInfo) const1224 void BatchNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1225 {
1226     const std::string descriptorName{"BatchNormalizationQueueDescriptor"};
1227 
1228     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1229     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1230 
1231     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1232     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1233 
1234     std::vector<DataType> supportedTypes =
1235     {
1236         DataType::BFloat16,
1237         DataType::Float16,
1238         DataType::Float32,
1239         DataType::QAsymmS8,
1240         DataType::QAsymmU8,
1241         DataType::QSymmS16
1242     };
1243 
1244     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1245     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1246 
1247     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1248     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1249 
1250     ValidatePointer(m_Mean,     descriptorName, "mean");
1251     ValidatePointer(m_Variance, descriptorName, "variance");
1252     ValidatePointer(m_Beta,     descriptorName, "beta");
1253     ValidatePointer(m_Gamma,    descriptorName, "gamma");
1254 
1255     const TensorInfo& mean     = m_Mean->GetTensorInfo();
1256     const TensorInfo& variance = m_Variance->GetTensorInfo();
1257     const TensorInfo& beta     = m_Beta->GetTensorInfo();
1258     const TensorInfo& gamma    = m_Gamma->GetTensorInfo();
1259 
1260     ValidateTensorNumDimensions(mean,     descriptorName, 1, "mean");
1261     ValidateTensorNumDimensions(variance, descriptorName, 1, "variance");
1262     ValidateTensorNumDimensions(beta,     descriptorName, 1, "beta");
1263     ValidateTensorNumDimensions(gamma,    descriptorName, 1, "gamma");
1264 
1265     ValidateTensorShapesMatch(mean, variance, descriptorName, "mean", "variance");
1266     ValidateTensorShapesMatch(mean, beta,     descriptorName, "mean", "beta");
1267     ValidateTensorShapesMatch(mean, gamma,    descriptorName, "mean", "gamma");
1268 }
1269 
// Validates a 2D convolution workload: input/weights/(bias) ranks and types,
// stride positivity, per-axis quantization consistency, and type agreement.
void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
    const std::string descriptorName{"Convolution2dQueueDescriptor"};

    // Inputs are: activations, weights and — when the bias is enabled — bias.
    uint32_t numInputs = 2;
    if (m_Parameters.m_BiasEnabled)
    {
        numInputs = 3;
    }

    ValidateNumInputs(workloadInfo,  descriptorName, numInputs);
    ValidateNumOutputs(workloadInfo, descriptorName, 1);

    const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
    const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];

    // 2D convolution works on rank-4 tensors (batch + 2 spatial dims + channels).
    ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
    ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");

    const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];

    ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");

    ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);

    Optional<TensorInfo> optionalBiasTensorInfo;
    if (m_Parameters.m_BiasEnabled)
    {
        optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
        const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();

        // The expected bias type is derived from the input type via GetBiasDataType
        // (e.g. Signed32 for quantized inputs).
        ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
        ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
    }

    if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0  )
    {
        throw InvalidArgumentException(
            fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
                        "cannot be either negative or 0.",
                        descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
    }

    // Checks the per-axis (per output channel) quantization scales, if used.
    ValidatePerAxisQuantization(inputTensorInfo,
                                outputTensorInfo,
                                weightTensorInfo,
                                optionalBiasTensorInfo,
                                descriptorName);

    std::vector<DataType> supportedTypes =
    {
        DataType::BFloat16,
        DataType::Float16,
        DataType::Float32,
        DataType::QAsymmS8,
        DataType::QAsymmU8,
        DataType::QSymmS16,
        DataType::QSymmS8
    };

    ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);

    // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
    if (inputTensorInfo.GetDataType() == DataType::BFloat16)
    {
        if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
        {
            throw InvalidArgumentException(descriptorName  + ": " + " Output tensor type must be BFloat16 or Float32 "
                                           "for BFloat16 input.");
        }
    }
    else
    {
        ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
    }
}
1346 
Validate(const WorkloadInfo & workloadInfo) const1347 void Convolution3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1348 {
1349     const std::string descriptorName{"Convolution3dQueueDescriptor"};
1350 
1351     uint32_t numInputs = 2;
1352     if (m_Parameters.m_BiasEnabled)
1353     {
1354         numInputs = 3;
1355     }
1356     ValidateNumInputs(workloadInfo,  descriptorName, numInputs);
1357     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1358 
1359     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1360     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1361 
1362     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 5, "input");
1363     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1364 
1365     const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1366     ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 5, "weight");
1367 
1368     ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1369 
1370     Optional<TensorInfo> optionalBiasTensorInfo;
1371     if (m_Parameters.m_BiasEnabled)
1372     {
1373         optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1374         const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1375 
1376         ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1377         ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1378     }
1379 
1380     if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0 || m_Parameters.m_StrideZ <= 0 )
1381     {
1382         throw InvalidArgumentException(
1383                 fmt::format("{}: strideX (provided {}), strideY (provided {}) or strideZ (provided {})"
1384                             "cannot be either negative or 0.",
1385                             descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY, m_Parameters.m_StrideZ));
1386     }
1387 
1388     ValidatePerAxisQuantization(inputTensorInfo,
1389                                 outputTensorInfo,
1390                                 weightTensorInfo,
1391                                 optionalBiasTensorInfo,
1392                                 descriptorName);
1393 
1394     std::vector<DataType> supportedTypes =
1395     {
1396         DataType::BFloat16,
1397         DataType::Float16,
1398         DataType::Float32,
1399         DataType::QAsymmS8,
1400         DataType::QAsymmU8,
1401         DataType::QSymmS16,
1402         DataType::QSymmS8
1403     };
1404 
1405     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1406     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1407 }
1408 
Validate(const WorkloadInfo & workloadInfo) const1409 void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1410 {
1411     const std::string descriptorName{"DepthwiseConvolution2dQueueDescriptor"};
1412 
1413     uint32_t numInputs = 2;
1414     if (m_Parameters.m_BiasEnabled)
1415     {
1416         numInputs = 3;
1417     }
1418 
1419     ValidateNumInputs(workloadInfo,  descriptorName, numInputs);
1420     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1421 
1422     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1423     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1424 
1425     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
1426     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1427 
1428     const TensorInfo& weightTensorInfo = workloadInfo.m_InputTensorInfos[1];
1429     ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
1430 
1431     if (m_Parameters.m_DilationX < 1 || m_Parameters.m_DilationY < 1 )
1432     {
1433         throw InvalidArgumentException(
1434             fmt::format("{}: dilationX (provided {}) and dilationY (provided {}) "
1435                         "cannot be smaller than 1.",
1436                         descriptorName, m_Parameters.m_DilationX, m_Parameters.m_DilationX));
1437     }
1438 
1439     if (m_Parameters.m_StrideX <= 0 || m_Parameters.m_StrideY <= 0  )
1440     {
1441         throw InvalidArgumentException(
1442             fmt::format("{}: strideX (provided {}) and strideY (provided {}) "
1443                         "cannot be either negative or 0.",
1444                         descriptorName, m_Parameters.m_StrideX, m_Parameters.m_StrideY));
1445     }
1446 
1447     if (weightTensorInfo.GetShape()[0] != 1)
1448     {
1449         throw InvalidArgumentException(fmt::format(
1450                 "{0}: The weight format in armnn is expected to be [1, H, W, Cout]."
1451                 "But first dimension is not equal to 1. Provided weight shape: [{1}, {2}, {3}, {4}]",
1452                 descriptorName,
1453                 weightTensorInfo.GetShape()[0],
1454                 weightTensorInfo.GetShape()[1],
1455                 weightTensorInfo.GetShape()[2],
1456                 weightTensorInfo.GetShape()[3]));
1457     }
1458 
1459     const unsigned int channelIndex = (m_Parameters.m_DataLayout == DataLayout::NCHW) ? 1 : 3;
1460     const unsigned int numWeightOutputChannelsRefFormat = weightTensorInfo.GetShape()[3];
1461     const unsigned int numWeightOutputChannelsAclFormat = weightTensorInfo.GetShape()[1];
1462     const unsigned int numOutputChannels = outputTensorInfo.GetShape()[channelIndex];
1463 
1464     // Weights format has two valid options: [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] (CpuAcc/GpuAcc).
1465     bool validRefFormat = (numWeightOutputChannelsRefFormat == numOutputChannels);
1466     bool validAclFormat = (numWeightOutputChannelsAclFormat == numOutputChannels);
1467 
1468     if (!(validRefFormat || validAclFormat))
1469     {
1470         throw InvalidArgumentException(fmt::format(
1471             "{0}: The weight format in armnn is expected to be [1, H, W, Cout] (CpuRef) or [1, Cout, H, W] "
1472             "(CpuAcc/GpuAcc). But neither the 4th (CpuRef) or 2nd (CpuAcc/GpuAcc) dimension is equal to Cout."
1473             "Cout = {1} Provided weight shape: [{2}, {3}, {4}, {5}]",
1474             descriptorName,
1475             numOutputChannels,
1476             weightTensorInfo.GetShape()[0],
1477             weightTensorInfo.GetShape()[1],
1478             weightTensorInfo.GetShape()[2],
1479             weightTensorInfo.GetShape()[3]));
1480     }
1481 
1482     ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
1483 
1484     Optional<TensorInfo> optionalBiasTensorInfo;
1485     if (m_Parameters.m_BiasEnabled)
1486     {
1487         optionalBiasTensorInfo = MakeOptional<TensorInfo>(workloadInfo.m_InputTensorInfos[2]);
1488         const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
1489 
1490         ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
1491         ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
1492     }
1493     ValidatePerAxisQuantization(inputTensorInfo,
1494                                 outputTensorInfo,
1495                                 weightTensorInfo,
1496                                 optionalBiasTensorInfo,
1497                                 descriptorName);
1498 
1499     std::vector<DataType> supportedTypes =
1500     {
1501         DataType::BFloat16,
1502         DataType::Float16,
1503         DataType::Float32,
1504         DataType::QAsymmS8,
1505         DataType::QAsymmU8,
1506         DataType::QSymmS16
1507     };
1508 
1509     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1510     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1511 }
1512 
Validate(const WorkloadInfo & workloadInfo) const1513 void PermuteQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1514 {
1515     const std::string descriptorName{"PermuteQueueDescriptor"};
1516 
1517     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1518     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1519 
1520     const PermutationVector& mapping = m_Parameters.m_DimMappings;
1521 
1522     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1523     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1524 
1525     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, mapping.GetSize(), "input");
1526     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
1527 
1528     for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
1529     {
1530         if (inputTensorInfo.GetShape()[i] != outputTensorInfo.GetShape()[mapping[i]])
1531         {
1532             throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(i) +
1533                                            " (=" + to_string(inputTensorInfo.GetShape()[i]) + ") " +
1534                                            "must match dst dimension " + to_string(mapping[i]) +
1535                                            " (=" + to_string(outputTensorInfo.GetShape()[mapping[i]]) + ")");
1536         }
1537     }
1538 
1539     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1540 }
1541 
Validate(const WorkloadInfo & workloadInfo) const1542 void Pooling2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1543 {
1544     const std::string descriptorName{"Pooling2dQueueDescriptor"};
1545 
1546     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1547     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1548 
1549     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1550     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1551 
1552     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
1553     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1554 
1555     std::vector<DataType> supportedTypes =
1556     {
1557         DataType::BFloat16,
1558         DataType::Float32,
1559         DataType::Float16,
1560         DataType::QAsymmS8,
1561         DataType::QAsymmU8,
1562         DataType::QSymmS16
1563     };
1564 
1565     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1566     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1567 }
1568 
Validate(const WorkloadInfo & workloadInfo) const1569 void Pooling3dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1570 {
1571     const std::string descriptorName{"Pooling3dQueueDescriptor"};
1572 
1573     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1574     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1575 
1576     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1577     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1578 
1579     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 5, "input");
1580     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 5, "output");
1581 
1582     std::vector<DataType> supportedTypes =
1583     {
1584         DataType::BFloat16,
1585         DataType::Float32,
1586         DataType::Float16,
1587         DataType::QAsymmS8,
1588         DataType::QAsymmU8,
1589         DataType::QSymmS16
1590     };
1591 
1592     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1593     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1594 }
1595 
Validate(const WorkloadInfo & workloadInfo) const1596 void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1597 {
1598     const std::string descriptorName{"ResizeQueueDescriptor"};
1599 
1600     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1601     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1602 
1603     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1604     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1605 
1606     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
1607     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1608 
1609     std::vector<DataType> supportedTypes =
1610     {
1611         DataType::BFloat16,
1612         DataType::Float16,
1613         DataType::Float32,
1614         DataType::QAsymmS8,
1615         DataType::QAsymmU8,
1616         DataType::QSymmS16
1617     };
1618 
1619     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1620     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1621 
1622     // Resize only changes width and height: batch and channel count must match.
1623     const unsigned int inputBatchSize  = inputTensorInfo.GetShape()[0];
1624     const unsigned int outputBatchSize = outputTensorInfo.GetShape()[0];
1625     if (inputBatchSize != outputBatchSize)
1626     {
1627         throw InvalidArgumentException(
1628                 fmt::format("{}: Input batch size ({}) does not match output batch size ({})",
1629                             descriptorName, inputBatchSize, outputBatchSize));
1630     }
1631 
1632     DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1633     const unsigned int inputChannelCount  = inputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1634     const unsigned int outputChannelCount = outputTensorInfo.GetShape()[dimensionIndices.GetChannelsIndex()];
1635     if (inputChannelCount != outputChannelCount)
1636     {
1637         throw InvalidArgumentException(
1638                 fmt::format("{}: Input channel count ({}) does not match output channel count ({})",
1639                             descriptorName, inputChannelCount, outputChannelCount));
1640     }
1641 }
1642 
Validate(const WorkloadInfo & workloadInfo) const1643 void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1644 {
1645     const std::string descriptorName{"FakeQuantizationQueueDescriptor"};
1646 
1647     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1648     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1649 
1650     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1651     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1652 
1653     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 2, "input");
1654     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 2, "output");
1655 
1656     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo,  descriptorName, "input", "output");
1657 
1658     if (m_Parameters.m_Min > m_Parameters.m_Max)
1659     {
1660         throw InvalidArgumentException(descriptorName + ": min cannot be greater than max");
1661     }
1662 }
1663 
Validate(const WorkloadInfo & workloadInfo) const1664 void InstanceNormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1665 {
1666     const std::string descriptorName{"InstanceNormalizationQueueDescriptor"};
1667 
1668     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1669     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1670 
1671     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1672     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1673 
1674     if (inputTensorInfo.GetNumDimensions() > 4)
1675     {
1676         throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1677     }
1678 
1679     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1680 
1681     // Check the supported data types
1682     std::vector<DataType> supportedTypes =
1683         {
1684             DataType::BFloat16,
1685             DataType::Float32,
1686             DataType::Float16
1687         };
1688 
1689     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1690     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1691 }
1692 
Validate(const WorkloadInfo & workloadInfo) const1693 void L2NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1694 {
1695     const std::string descriptorName{"L2NormalizationQueueDescriptor"};
1696 
1697     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1698     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1699 
1700     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1701     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1702 
1703     if (inputTensorInfo.GetNumDimensions() > 4)
1704     {
1705         throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
1706     }
1707 
1708     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1709 
1710     // Check the supported data types
1711     std::vector<DataType> supportedTypes =
1712     {
1713         DataType::BFloat16,
1714         DataType::Float32,
1715         DataType::Float16,
1716         DataType::QAsymmS8,
1717         DataType::QAsymmU8,
1718         DataType::QSymmS16
1719     };
1720 
1721     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1722     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1723 }
1724 
Validate(const WorkloadInfo & workloadInfo) const1725 void LogSoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1726 {
1727     const std::string descriptorName{"LogSoftmaxQueueDescriptor"};
1728 
1729     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1730     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1731 
1732     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1733     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1734 
1735     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1736 
1737     std::vector<DataType> supportedTypes =
1738     {
1739         DataType::BFloat16,
1740         DataType::Float32,
1741         DataType::Float16,
1742     };
1743 
1744     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1745     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1746 }
1747 
Validate(const WorkloadInfo & workloadInfo) const1748 void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1749 {
1750     const std::string descriptorName{"ConstantQueueDescriptor"};
1751 
1752     ValidateNumInputs(workloadInfo,  descriptorName, 0);
1753     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1754 
1755     if (!m_LayerOutput)
1756     {
1757         throw InvalidArgumentException(descriptorName + ": No const input specified.");
1758     }
1759 
1760     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1761     ValidateTensorShapesMatch(m_LayerOutput->GetTensorInfo(), outputTensorInfo, descriptorName, "constant", "output");
1762 
1763     // Check the supported data types
1764     std::vector<DataType> supportedTypes =
1765     {
1766         DataType::BFloat16,
1767         DataType::Float32,
1768         DataType::Float16,
1769         DataType::QAsymmS8,
1770         DataType::QAsymmU8,
1771         DataType::QSymmS8,
1772         DataType::QSymmS16,
1773         DataType::Signed32
1774     };
1775 
1776     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1777 }
1778 
Validate(const WorkloadInfo & workloadInfo) const1779 void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1780 {
1781     const std::string descriptorName{"ReshapeQueueDescriptor"};
1782 
1783     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1784     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1785 
1786     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1787     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1788 
1789     ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1790 
1791     // Check the supported data types
1792     std::vector<DataType> supportedTypes =
1793     {
1794         DataType::BFloat16,
1795         DataType::Float32,
1796         DataType::Float16,
1797         DataType::QAsymmS8,
1798         DataType::QAsymmU8,
1799         DataType::QSymmS16,
1800         DataType::Signed32,
1801         DataType::Boolean
1802     };
1803 
1804     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1805     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1806 }
1807 
Validate(const WorkloadInfo & workloadInfo) const1808 void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1809 {
1810     const std::string descriptorName{"SpaceToBatchNdQueueDescriptor"};
1811 
1812     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1813     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1814 
1815     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1816     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1817 
1818     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
1819     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1820 
1821     if (m_Parameters.m_BlockShape.size() != 2)
1822     {
1823         throw InvalidArgumentException(descriptorName + ": Block Shape must contain 2 spatial dimensions.");
1824     }
1825 
1826     if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
1827     {
1828         throw InvalidArgumentException(descriptorName + ": Pad List must contain the same number of "
1829                                        "dimensions as Block Shape.");
1830     }
1831 
1832     const TensorShape& inputShape = inputTensorInfo.GetShape();
1833 
1834     std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
1835     std::pair<unsigned int, unsigned int> widthPad  = m_Parameters.m_PadList[1];
1836 
1837     DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1838 
1839     const unsigned int inputWidth  = inputShape[dimensionIndices.GetWidthIndex()] +
1840                                      widthPad.first + widthPad.second;
1841     const unsigned int inputHeight = inputShape[dimensionIndices.GetHeightIndex()] +
1842                                      heightPad.first + heightPad.second;
1843 
1844     const unsigned int numInputElements  = inputShape[0] * inputHeight * inputWidth *
1845                                            inputShape[dimensionIndices.GetChannelsIndex()];
1846     const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
1847 
1848     if (numOutputElements != numInputElements)
1849     {
1850         throw InvalidArgumentException(descriptorName + ": Input tensor has " +
1851             to_string(numInputElements) + " after padding but output tensor has " +
1852             to_string(numOutputElements) + " elements.");
1853     }
1854 
1855     if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
1856     {
1857         throw InvalidArgumentException(descriptorName + ": Input shape after padding must be "
1858                                        "divisible by Block Shape in all spatial dimensions");
1859     }
1860 
1861     std::vector<DataType> supportedTypes =
1862     {
1863         DataType::BFloat16,
1864         DataType::Float16,
1865         DataType::Float32,
1866         DataType::QAsymmS8,
1867         DataType::QAsymmU8,
1868         DataType::QSymmS16
1869     };
1870 
1871     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
1872     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1873 }
1874 
Validate(const WorkloadInfo & workloadInfo) const1875 void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1876 {
1877     const std::string descriptorName{"SpaceToDepthQueueDescriptor"};
1878 
1879     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1880     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1881 
1882     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1883     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1884 
1885     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
1886     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
1887 
1888     std::vector<DataType> supportedTypes =
1889     {
1890         DataType::BFloat16,
1891         DataType::Float32,
1892         DataType::Float16,
1893         DataType::QAsymmS8,
1894         DataType::QAsymmU8,
1895         DataType::QSymmS16
1896     };
1897 
1898     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1899     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
1900 
1901     ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1902 
1903     if (m_Parameters.m_BlockSize == 0)
1904     {
1905         throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
1906     }
1907 
1908     DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
1909     const unsigned int wIndex = dimensionIndices.GetWidthIndex();
1910     const unsigned int hIndex = dimensionIndices.GetHeightIndex();
1911     const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
1912 
1913     const TensorShape& inputShape = inputTensorInfo.GetShape();
1914     if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex]  % m_Parameters.m_BlockSize != 0)
1915     {
1916         throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
1917                                        "by block size in all spatial dimensions");
1918     }
1919 
1920     const TensorShape& outputShape = outputTensorInfo.GetShape();
1921     if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
1922     {
1923         throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
1924                                        "must be divisible by the square of block size." );
1925     }
1926 }
1927 
Validate(const WorkloadInfo & workloadInfo) const1928 void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1929 {
1930     const std::string descriptorName{"FloorQueueDescriptor"};
1931 
1932     ValidateNumInputs(workloadInfo,  descriptorName, 1);
1933     ValidateNumOutputs(workloadInfo, descriptorName, 1);
1934 
1935     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
1936     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
1937 
1938     std::vector<DataType> supportedTypes =
1939     {
1940         DataType::BFloat16,
1941         DataType::Float32,
1942         DataType::Float16,
1943         DataType::QSymmS16
1944     };
1945 
1946     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
1947     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1948     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1949     ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
1950 }
1951 
Validate(const WorkloadInfo & workloadInfo) const1952 void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
1953 {
1954     // ported from android/ml/nn/common/operations/LSTM.cpp CheckInputTensorDimensions()
1955 
1956     const std::string descriptorName{"LstmQueueDescriptor"};
1957 
1958     // check dimensions of all inputs and outputs
1959     if (workloadInfo.m_InputTensorInfos.size() != 3)
1960     {
1961         throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
1962     }
1963     if (workloadInfo.m_OutputTensorInfos.size() != 4)
1964     {
1965         throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
1966     }
1967 
1968     std::vector<DataType> supportedTypes =
1969     {
1970         DataType::BFloat16,
1971         DataType::Float16,
1972         DataType::Float32,
1973         DataType::QSymmS16
1974     };
1975 
1976     // check for supported type of one input and match them with all the other input and output
1977     ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
1978 
1979     // type matches all other inputs
1980     for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
1981     {
1982         ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1983                                      workloadInfo.m_InputTensorInfos[i],
1984                                      descriptorName,
1985                                      "input_0",
1986                                      "input_" + std::to_string(i));
1987     }
1988     // type matches all other outputs
1989     for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
1990     {
1991         ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
1992                                      workloadInfo.m_OutputTensorInfos[i],
1993                                      "LstmQueueDescriptor",
1994                                      "input_0",
1995                                      "output_" + std::to_string(i));
1996     }
1997 
1998     // Making sure clipping parameters have valid values.
1999     // == 0 means no clipping
2000     //  > 0 means clipping
2001     if (m_Parameters.m_ClippingThresCell < 0.0f)
2002     {
2003         throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
2004     }
2005     if (m_Parameters.m_ClippingThresProj < 0.0f)
2006     {
2007         throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
2008     }
2009 
2010     // Inferring batch size, number of outputs and number of cells from the inputs.
2011     const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
2012     const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
2013     ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
2014     const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
2015     ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
2016     const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
2017 
2018     // input tensor
2019     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 2, (n_batch * n_input),
2020                                 descriptorName + " input_0");
2021     // outputStateInTensor
2022     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
2023                                 descriptorName + " input_1");
2024     // outputStateInTensor
2025     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
2026                                 descriptorName + " input_2");
2027     // scratchBufferTensor
2028     unsigned int scratchBufferSize = m_Parameters.m_CifgEnabled ? n_cell * 3 : n_cell * 4;
2029     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 2, (n_batch * scratchBufferSize),
2030                                 descriptorName + " output_0");
2031     // outputStateOutTensor
2032     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[1], 2, (n_batch * n_output),
2033                                 descriptorName + " output_1");
2034     // cellStateOutTensor
2035     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 2, (n_batch * n_cell),
2036                                 descriptorName + " output_2");
2037     // outputTensor
2038     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
2039                                 descriptorName + " output_3");
2040 
2041     // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
2042     if ( m_InputToInputWeights )
2043     {
2044         ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
2045                                       (n_cell * n_input), "InputLayerNormWeights");
2046     }
2047 
2048     ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
2049     ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
2050                                   (n_cell * n_input), "InputToForgetWeights");
2051 
2052     ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
2053     ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
2054                                   (n_cell * n_input), "InputToCellWeights");
2055 
2056     if ( m_RecurrentToInputWeights )
2057     {
2058         ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
2059                                       (n_cell * n_output), "RecurrentToInputWeights");
2060     }
2061 
2062     ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
2063     ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
2064                                   (n_cell * n_output), "RecurrentToForgetWeights");
2065 
2066     ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
2067     ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
2068                                   (n_cell * n_output), "RecurrentToCellWeights");
2069 
2070     // Make sure the input-gate's parameters are either both present (regular
2071     // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
2072     bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
2073                                      !m_Parameters.m_CifgEnabled) ||
2074                                      (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
2075                                      m_Parameters.m_CifgEnabled));
2076     if (!cifg_weights_all_or_none)
2077     {
2078         throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
2079                                        "RecurrentToInputWeights must either both be present (regular LSTM) "
2080                                        "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
2081                                        "accordingly.");
2082     }
2083 
2084     if ( m_CellToInputWeights )
2085     {
2086         ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
2087                                       n_cell, "CellToInputWeights");
2088     }
2089     if ( m_CellToForgetWeights )
2090     {
2091         ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
2092                                       n_cell, "CellToForgetWeights");
2093     }
2094     if ( m_CellToOutputWeights )
2095     {
2096         ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
2097                                       n_cell, "CellToOutputWeights");
2098     }
2099 
2100     // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
2101     bool peephole_weights_all_or_none =
2102             (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) &&  m_CellToForgetWeights
2103             && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
2104             || ( !m_CellToInputWeights && !m_CellToForgetWeights
2105             && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
2106     if (!peephole_weights_all_or_none)
2107     {
2108         throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
2109     }
2110 
2111     // Make sure the input gate bias is present only when not a CIFG-LSTM.
2112     if (m_Parameters.m_CifgEnabled)
2113     {
2114         if (m_InputGateBias)
2115         {
2116             throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
2117         }
2118     }
2119     else
2120     {
2121         if (!m_InputGateBias)
2122         {
2123             throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
2124                                            "must be present.");
2125         }
2126         ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
2127                                       n_cell, "InputGateBias");
2128     }
2129 
2130     ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
2131     ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
2132 
2133     ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
2134     ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
2135 
2136     ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
2137     ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
2138 
2139     if (m_ProjectionWeights)
2140     {
2141         ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
2142                                       (n_cell * n_output), "ProjectionWeights");
2143     }
2144     if (m_ProjectionBias)
2145     {
2146         ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
2147     }
2148 
2149     // Making sure the projection tensors are consistent:
2150     // 1) If projection weight is not present, then projection bias should not be
2151     // present.
2152     // 2) If projection weight is present, then projection bias is optional.
2153     bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
2154                                         !m_Parameters.m_ProjectionEnabled)
2155                                         || (m_ProjectionWeights && !m_ProjectionBias &&
2156                                         m_Parameters.m_ProjectionEnabled)
2157                                         || (m_ProjectionWeights && m_ProjectionBias &&
2158                                         m_Parameters.m_ProjectionEnabled));
2159     if (!projecton_tensors_consistent)
2160     {
2161         throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
2162     }
2163 
2164     // The four layer normalization weights either all have values or none of them have values. Additionally, if
2165     // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
2166     // either all have values or none of them have values. Layer normalization is used when the values of all the
2167     // layer normalization weights are present
2168     if (m_InputLayerNormWeights)
2169     {
2170         ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
2171     }
2172     if (m_ForgetLayerNormWeights)
2173     {
2174         ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2175     }
2176     if (m_CellLayerNormWeights)
2177     {
2178         ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2179     }
2180     if (m_OutputLayerNormWeights)
2181     {
2182         ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2183     }
2184 
2185     if (m_Parameters.m_LayerNormEnabled)
2186     {
2187         if (!m_Parameters.m_CifgEnabled)
2188         {
2189             if (!m_InputLayerNormWeights)
2190             {
2191                 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
2192                                                "disabled but InputLayerNormWeights are not present");
2193             }
2194             ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
2195                                           1, n_cell, "InputLayerNormWeights");
2196         }
2197         else if (m_InputLayerNormWeights)
2198         {
2199             throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
2200                                            "enabled");
2201         }
2202 
2203         ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
2204                         "ForgetLayerNormWeights");
2205         ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
2206 
2207         ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
2208                         "OutputLayerNormWeights");
2209         ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
2210 
2211         ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
2212                         "CellLayerNormWeights");
2213         ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
2214     }
2215     else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
2216     {
2217         throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
2218                                        "normalisation weights are present.");
2219     }
2220 }
2221 
Validate(const WorkloadInfo & workloadInfo) const2222 void ConvertFp32ToFp16QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2223 {
2224     const std::string descriptorName{"ConvertFp32ToFp16QueueDescriptor"};
2225 
2226     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2227     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2228 
2229     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2230     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2231 
2232     if (inputTensorInfo.GetDataType() != DataType::Float32)
2233     {
2234         throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float32.");
2235     }
2236 
2237     if (outputTensorInfo.GetDataType() != DataType::Float16)
2238     {
2239         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float16.");
2240     }
2241 
2242     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2243 }
2244 
Validate(const WorkloadInfo & workloadInfo) const2245 void ConvertFp16ToFp32QueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2246 {
2247     const std::string descriptorName{"ConvertFp16ToFp32QueueDescriptor"};
2248 
2249     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2250     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2251 
2252     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2253     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2254 
2255     if (inputTensorInfo.GetDataType() != DataType::Float16)
2256     {
2257         throw InvalidArgumentException(descriptorName + ": Input tensor type must be Float16.");
2258     }
2259 
2260     if (outputTensorInfo.GetDataType() != DataType::Float32)
2261     {
2262         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Float32.");
2263     }
2264 
2265     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2266 }
2267 
Validate(const WorkloadInfo & workloadInfo) const2268 void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2269 {
2270     const std::string descriptorName{"DivisionQueueDescriptor"};
2271 
2272     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2273     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2274 
2275     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2276     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2277     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2278 
2279     std::vector<DataType> supportedTypes =
2280     {
2281         DataType::BFloat16,
2282         DataType::Float16,
2283         DataType::Float32,
2284         DataType::QAsymmS8,
2285         DataType::QAsymmU8,
2286         DataType::QSymmS16,
2287         DataType::Signed32
2288     };
2289 
2290     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2291     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2292     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2293 
2294     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2295                                        inputTensorInfo1,
2296                                        outputTensorInfo,
2297                                        descriptorName,
2298                                        "input_0",
2299                                        "input_1");
2300 }
2301 
Validate(const WorkloadInfo & workloadInfo) const2302 void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2303 {
2304     const std::string descriptorName{"SubtractionQueueDescriptor"};
2305 
2306     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2307     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2308 
2309     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2310     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2311     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2312 
2313     std::vector<DataType> supportedTypes =
2314     {
2315         DataType::BFloat16,
2316         DataType::Float16,
2317         DataType::Float32,
2318         DataType::QAsymmS8,
2319         DataType::QAsymmU8,
2320         DataType::QSymmS16,
2321         DataType::Signed32,
2322     };
2323 
2324     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2325     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2326     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2327 
2328     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2329                                        inputTensorInfo1,
2330                                        outputTensorInfo,
2331                                        descriptorName,
2332                                        "input_0",
2333                                        "input_1");
2334 }
2335 
Validate(const WorkloadInfo & workloadInfo) const2336 void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2337 {
2338     const std::string descriptorName{"MaximumQueueDescriptor"};
2339 
2340     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2341     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2342 
2343     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2344     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2345     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2346 
2347     std::vector<DataType> supportedTypes =
2348     {
2349         DataType::BFloat16,
2350         DataType::Float16,
2351         DataType::Float32,
2352         DataType::QAsymmS8,
2353         DataType::QAsymmU8,
2354         DataType::QSymmS16,
2355         DataType::Signed32
2356     };
2357 
2358     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2359     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2360     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2361 
2362     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2363                                        inputTensorInfo1,
2364                                        outputTensorInfo,
2365                                        descriptorName,
2366                                        "input_0",
2367                                        "input_1");
2368 }
2369 
Validate(const WorkloadInfo & workloadInfo) const2370 void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2371 {
2372     const std::string descriptorName{"MeanQueueDescriptor"};
2373 
2374     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2375     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2376 
2377     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2378     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2379 
2380     std::vector<DataType> supportedTypes =
2381     {
2382         DataType::BFloat16,
2383         DataType::Float32,
2384         DataType::Float16,
2385         DataType::QAsymmS8,
2386         DataType::QAsymmU8,
2387         DataType::QSymmS16
2388     };
2389 
2390     // First check if input tensor data type is supported, then
2391     // check if this data type matches the output tensor data type
2392     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
2393     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2394 
2395     if (m_Parameters.m_KeepDims)
2396     {
2397         ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2398     }
2399     else if (m_Parameters.m_Axis.empty())
2400     {
2401         ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
2402     }
2403     else
2404     {
2405         unsigned int outputDim =
2406             inputTensorInfo.GetNumDimensions() - armnn::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
2407         ValidateTensorNumDimensions(outputTensorInfo,
2408                                     descriptorName,
2409                                     outputDim > 0 ? outputDim : 1,
2410                                     "output");
2411     }
2412 }
2413 
Validate(const WorkloadInfo & workloadInfo) const2414 void PadQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2415 {
2416     const std::string descriptorName{"PadQueueDescriptor"};
2417 
2418     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2419     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2420 
2421     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2422     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2423 
2424     // input and output should have the same number of dimensions
2425     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, inputTensorInfo.GetNumDimensions(), "output");
2426 
2427     // there should be entry in the pad list for each dimension in the input tensor
2428     if (m_Parameters.m_PadList.size() != inputTensorInfo.GetNumDimensions()) {
2429         throw InvalidArgumentException(descriptorName + ":Pad List should contain the same number of entries "
2430                                        "as there are dimensions in the input tensor that is " +
2431                                        std::to_string(inputTensorInfo.GetNumDimensions()) + " entries " +
2432                                        " not " + std::to_string(m_Parameters.m_PadList.size()) + " entries.");
2433     }
2434 }
2435 
Validate(const WorkloadInfo & workloadInfo) const2436 void QuantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2437 {
2438     const std::string descriptorName{"QuantizeQueueDescriptor"};
2439 
2440     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2441     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2442 
2443     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2444     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2445 
2446     std::vector<DataType> supportedTypes =
2447     {
2448         DataType::BFloat16,
2449         DataType::Float32,
2450         DataType::Float16,
2451         DataType::QSymmS8,
2452         DataType::QAsymmS8,
2453         DataType::QAsymmU8,
2454         DataType::QSymmS16
2455     };
2456 
2457     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2458 
2459     if (!IsQuantizedType(outputTensorInfo.GetDataType()))
2460     {
2461         throw InvalidArgumentException(descriptorName + ": Output of quantized layer must be quantized type.");
2462     }
2463 }
2464 
Validate(const WorkloadInfo & workloadInfo) const2465 void BatchToSpaceNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2466 {
2467     const std::string descriptorName{"BatchToSpaceNdQueueDescriptor"};
2468 
2469     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2470     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2471 
2472     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2473     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2474 
2475     std::vector<DataType> supportedTypes =
2476     {
2477         DataType::BFloat16,
2478         DataType::Float32,
2479         DataType::Float16,
2480         DataType::QAsymmS8,
2481         DataType::QAsymmU8,
2482         DataType::QSymmS16
2483     };
2484 
2485     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2486     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2487 }
2488 
Validate(const WorkloadInfo & workloadInfo) const2489 void StridedSliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2490 {
2491     const std::string descriptorName{"StridedSliceQueueDescriptor"};
2492 
2493     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2494     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2495 
2496     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2497     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2498 
2499     std::vector<DataType> supportedTypes =
2500     {
2501         DataType::BFloat16,
2502         DataType::Float16,
2503         DataType::Float32,
2504         DataType::QAsymmS8,
2505         DataType::QAsymmU8,
2506         DataType::QSymmS16
2507     };
2508 
2509     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2510     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2511 
2512     ValidateTensorQuantizationSpace(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2513 
2514     const uint32_t rank = inputTensorInfo.GetNumDimensions();
2515     if (rank > 4)
2516     {
2517         throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
2518     }
2519 
2520     // Begin, End & Stride length must be of rank(input0)
2521     if (m_Parameters.m_Begin.size() != rank)
2522     {
2523         throw InvalidArgumentException(descriptorName + ": Begin length must be of rank " + std::to_string(rank));
2524     }
2525 
2526     if (m_Parameters.m_End.size() != rank)
2527     {
2528         throw InvalidArgumentException(descriptorName + ": End length must be of rank " + std::to_string(rank));
2529     }
2530 
2531     if (m_Parameters.m_Stride.size() != rank)
2532     {
2533         throw InvalidArgumentException(descriptorName + ": Stride length must be of rank " + std::to_string(rank));
2534     }
2535 
2536     // Stride entries must be non-zero
2537     for (auto& stride : m_Parameters.m_Stride)
2538     {
2539         if (stride == 0)
2540         {
2541             throw InvalidArgumentException(descriptorName + ": Stride entries must be non-zero.");
2542         }
2543     }
2544 }
2545 
Validate(const WorkloadInfo & workloadInfo) const2546 void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2547 {
2548     const std::string descriptorName{"MinimumQueueDescriptor"};
2549 
2550     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2551     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2552 
2553     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2554     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2555     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2556 
2557     std::vector<DataType> supportedTypes =
2558     {
2559         DataType::BFloat16,
2560         DataType::Float16,
2561         DataType::Float32,
2562         DataType::QAsymmS8,
2563         DataType::QAsymmU8,
2564         DataType::QSymmS16,
2565         DataType::Signed32
2566     };
2567 
2568     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2569     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2570     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2571 
2572     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2573                                        inputTensorInfo1,
2574                                        outputTensorInfo,
2575                                        descriptorName,
2576                                        "input_0",
2577                                        "input_1");
2578 }
2579 
Validate(const WorkloadInfo & workloadInfo) const2580 void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2581 {
2582     const std::string descriptorName{"DebugQueueDescriptor"};
2583 
2584     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2585     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2586 }
2587 
Validate(const WorkloadInfo & workloadInfo) const2588 void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2589 {
2590     const std::string descriptorName{"EqualQueueDescriptor"};
2591 
2592     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2593     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2594 
2595     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2596     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2597     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2598 
2599     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2600                                        inputTensorInfo1,
2601                                        outputTensorInfo,
2602                                        descriptorName,
2603                                        "input_0",
2604                                        "input_1");
2605 
2606     if (outputTensorInfo.GetDataType() != DataType::Boolean)
2607     {
2608         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2609     }
2610 }
2611 
Validate(const WorkloadInfo & workloadInfo) const2612 void GreaterQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2613 {
2614     const std::string descriptorName{"GreaterQueueDescriptor"};
2615 
2616     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2617     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2618 
2619     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2620     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2621     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2622 
2623     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
2624                                        inputTensorInfo1,
2625                                        outputTensorInfo,
2626                                        descriptorName,
2627                                        "input_0",
2628                                        "input_1");
2629 
2630     if (outputTensorInfo.GetDataType() != DataType::Boolean)
2631     {
2632         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
2633     }
2634 }
2635 
Validate(const WorkloadInfo & workloadInfo) const2636 void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2637 {
2638     const std::string descriptorName{"RsqrtQueueDescriptor"};
2639 
2640     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2641     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2642 
2643     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2644     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2645 
2646     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2647 
2648     std::vector<DataType> supportedTypes =
2649     {
2650         DataType::BFloat16,
2651         DataType::Float16,
2652         DataType::Float32,
2653         DataType::QAsymmS8,
2654         DataType::QAsymmU8,
2655         DataType::QSymmS16
2656     };
2657 
2658     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2659     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2660 }
2661 
Validate(const WorkloadInfo & workloadInfo) const2662 void GatherNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2663 {
2664     const std::string descriptorName{"GatherNdQueueDescriptor"};
2665 
2666     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2667     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2668 
2669     const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2670     if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2671     {
2672         throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2673     }
2674 
2675     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2676     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2677 
2678     std::vector<DataType> supportedTypes =
2679             {
2680                     DataType::BFloat16,
2681                     DataType::Float16,
2682                     DataType::Float32,
2683                     DataType::QAsymmS8,
2684                     DataType::QAsymmU8,
2685                     DataType::QSymmS16,
2686                     DataType::Signed32,
2687             };
2688 
2689     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2690 
2691     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2692 
2693     unsigned int outputDim  = outputTensorInfo.GetNumDimensions();
2694     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2695 }
2696 
Validate(const WorkloadInfo & workloadInfo) const2697 void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2698 {
2699     const std::string descriptorName{"GatherQueueDescriptor"};
2700 
2701     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2702     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2703 
2704     const TensorInfo& indicesTensorInfo = workloadInfo.m_InputTensorInfos[1];
2705     if (indicesTensorInfo.GetDataType() != DataType::Signed32)
2706     {
2707         throw InvalidArgumentException(descriptorName + ": Indices tensor type must be Int32.");
2708     }
2709 
2710     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2711     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2712 
2713     std::vector<DataType> supportedTypes =
2714     {
2715         DataType::BFloat16,
2716         DataType::Float16,
2717         DataType::Float32,
2718         DataType::QAsymmS8,
2719         DataType::QAsymmU8,
2720         DataType::QSymmS16,
2721         DataType::Signed32,
2722     };
2723 
2724     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2725 
2726     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
2727 
2728     unsigned int outputDim  = inputTensorInfo.GetNumDimensions() + indicesTensorInfo.GetNumDimensions() - 1;
2729     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, outputDim, "output");
2730 }
2731 
Validate(const WorkloadInfo & workloadInfo) const2732 void DetectionPostProcessQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2733 {
2734     const std::string& descriptorName{"DetectionPostProcessQueueDescriptor"};
2735 
2736     ValidateNumInputs(workloadInfo, descriptorName, 2);
2737 
2738     if (workloadInfo.m_OutputTensorInfos.size() != 4)
2739     {
2740         throw InvalidArgumentException(descriptorName + ": Requires exactly four outputs. " +
2741                                        to_string(workloadInfo.m_OutputTensorInfos.size()) + " has been provided.");
2742     }
2743 
2744     if (m_Anchors == nullptr)
2745     {
2746         throw InvalidArgumentException(descriptorName + ": Anchors tensor descriptor is missing.");
2747     }
2748 
2749     const TensorInfo& boxEncodingsInfo =  workloadInfo.m_InputTensorInfos[0];
2750     const TensorInfo& scoresInfo       =  workloadInfo.m_InputTensorInfos[1];
2751     const TensorInfo& anchorsInfo      = m_Anchors->GetTensorInfo();
2752 
2753     const TensorInfo& detectionBoxesInfo   = workloadInfo.m_OutputTensorInfos[0];
2754     const TensorInfo& detectionClassesInfo = workloadInfo.m_OutputTensorInfos[1];
2755     const TensorInfo& detectionScoresInfo  = workloadInfo.m_OutputTensorInfos[2];
2756     const TensorInfo& numDetectionsInfo    = workloadInfo.m_OutputTensorInfos[3];
2757 
2758     ValidateTensorNumDimensions(boxEncodingsInfo, descriptorName, 3, "box encodings");
2759     ValidateTensorNumDimensions(scoresInfo, descriptorName, 3, "scores");
2760     ValidateTensorNumDimensions(anchorsInfo, descriptorName, 2, "anchors");
2761 
2762     const std::vector<DataType> supportedInputTypes =
2763     {
2764         DataType::BFloat16,
2765         DataType::Float32,
2766         DataType::Float16,
2767         DataType::QAsymmS8,
2768         DataType::QAsymmU8,
2769         DataType::QSymmS16
2770     };
2771 
2772     ValidateDataTypes(boxEncodingsInfo, supportedInputTypes, descriptorName);
2773     ValidateDataTypes(scoresInfo, supportedInputTypes, descriptorName);
2774     ValidateDataTypes(anchorsInfo, supportedInputTypes, descriptorName);
2775 
2776     ValidateTensorNumDimensions(detectionBoxesInfo, descriptorName, 3, "detection boxes");
2777     ValidateTensorNumDimensions(detectionScoresInfo, descriptorName, 2, "detection scores");
2778     ValidateTensorNumDimensions(detectionClassesInfo, descriptorName, 2, "detection classes");
2779     ValidateTensorNumDimensions(numDetectionsInfo, descriptorName, 1, "num detections");
2780 
2781     // NOTE: Output is always Float32 regardless of input type
2782     ValidateTensorDataType(detectionBoxesInfo, DataType::Float32, descriptorName, "detection boxes");
2783     ValidateTensorDataType(detectionScoresInfo, DataType::Float32, descriptorName, "detection scores");
2784     ValidateTensorDataType(detectionClassesInfo, DataType::Float32, descriptorName, "detection classes");
2785     ValidateTensorDataType(numDetectionsInfo, DataType::Float32, descriptorName, "num detections");
2786 
2787     if (m_Parameters.m_NmsIouThreshold <= 0.0f || m_Parameters.m_NmsIouThreshold > 1.0f)
2788     {
2789         throw InvalidArgumentException(descriptorName + ": Intersection over union threshold "
2790                                        "must be positive and less than or equal to 1.");
2791     }
2792 
2793     if (scoresInfo.GetShape()[2] != m_Parameters.m_NumClasses + 1)
2794     {
2795         throw InvalidArgumentException(descriptorName + ": Number of classes with background "
2796                                        "should be equal to number of classes + 1.");
2797     }
2798 }
2799 
Validate(const WorkloadInfo & workloadInfo) const2800 void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2801 {
2802     const std::string& descriptorName{"DequantizeQueueDescriptor"};
2803 
2804     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2805     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2806 
2807     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2808     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2809 
2810     std::vector<DataType> inputSupportedTypes =
2811     {
2812             DataType::QAsymmS8,
2813             DataType::QAsymmU8,
2814             DataType::QSymmS8,
2815             DataType::QSymmS16,
2816             DataType::Float16
2817     };
2818     ValidateDataTypes(inputTensorInfo, inputSupportedTypes, descriptorName);
2819 
2820     std::vector<DataType> outputSupportedTypes =
2821     {
2822         DataType::BFloat16,
2823         DataType::Float32,
2824         DataType::Float16
2825     };
2826 
2827     ValidateDataTypes(outputTensorInfo, outputSupportedTypes, descriptorName);
2828 }
2829 
Validate(const WorkloadInfo & workloadInfo) const2830 void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2831 {
2832     const std::string& descriptorName{"MergeQueueDescriptor"};
2833 
2834     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2835     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2836 
2837     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2838     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2839     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2840 
2841     ValidateTensorShapesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2842     ValidateTensorShapesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2843 
2844     ValidateTensorDataTypesMatch(inputTensorInfo0, inputTensorInfo1, descriptorName, "input_0", "input_1");
2845     ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input_0", "output");
2846 }
2847 
Validate(const WorkloadInfo & workloadInfo) const2848 void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2849 {
2850     const std::string& descriptorName{"ShapeQueueDescriptor"};
2851 
2852     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2853     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2854 
2855     const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
2856     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2857 
2858     std::vector<DataType> supportedTypes =
2859     {
2860         DataType::BFloat16,
2861         DataType::Float16,
2862         DataType::Float32,
2863         DataType::QAsymmS8,
2864         DataType::QAsymmU8,
2865         DataType::QAsymmS8,
2866         DataType::QSymmS8,
2867         DataType::QSymmS16,
2868         DataType::Signed32
2869     };
2870 
2871     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2872     ValidateDataTypes(outputTensorInfo, {DataType::Signed32}, descriptorName);
2873 }
2874 
Validate(const WorkloadInfo & workloadInfo) const2875 void SwitchQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2876 {
2877     const std::string& descriptorName{"SwitchQueueDescriptor"};
2878 
2879     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2880     ValidateNumOutputs(workloadInfo, descriptorName, 2);
2881 
2882     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
2883     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
2884 
2885     const TensorInfo& outputTensorInfo0 = workloadInfo.m_OutputTensorInfos[0];
2886     const TensorInfo& outputTensorInfo1 = workloadInfo.m_OutputTensorInfos[1];
2887 
2888     std::vector<DataType> supportedTypes =
2889     {
2890         DataType::BFloat16,
2891         DataType::Float32,
2892         DataType::QAsymmS8,
2893         DataType::QAsymmU8,
2894         DataType::QSymmS16
2895     };
2896 
2897     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
2898     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
2899 
2900     ValidateDataTypes(outputTensorInfo0, supportedTypes, descriptorName);
2901     ValidateDataTypes(outputTensorInfo1, supportedTypes, descriptorName);
2902 
2903     ValidateTensorShapesMatch(inputTensorInfo0,
2904                               outputTensorInfo0,
2905                               descriptorName,
2906                               "input_0",
2907                               "output_0");
2908 
2909     ValidateTensorShapesMatch(inputTensorInfo0,
2910                               outputTensorInfo1,
2911                               descriptorName,
2912                               "input_0",
2913                               "output_1");
2914 }
2915 
// No-op: pre-compiled workloads are produced internally (not from user input),
// so there is nothing external to validate here.
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& /*workloadInfo*/) const
{
    // This is internally generated so it should not need validation.
}
2920 
Validate(const WorkloadInfo & workloadInfo) const2921 void PreluQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2922 {
2923     const std::string& descriptorName{"PreluQueueDescriptor"};
2924 
2925     ValidateNumInputs(workloadInfo,  descriptorName, 2);
2926     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2927 
2928     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2929     const TensorInfo& alphaTensorInfo  = workloadInfo.m_InputTensorInfos[1];
2930     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2931 
2932     std::vector<DataType> supportedTypes
2933     {
2934         DataType::BFloat16,
2935         DataType::Float16,
2936         DataType::Float32,
2937         DataType::QAsymmS8,
2938         DataType::QAsymmU8,
2939         DataType::QSymmS16
2940     };
2941 
2942     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
2943     ValidateDataTypes(alphaTensorInfo, supportedTypes, descriptorName);
2944 
2945     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
2946 
2947     ValidateTensorDataTypesMatch(inputTensorInfo, alphaTensorInfo,  descriptorName, "input", "alpha");
2948     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "ouptut");
2949 
2950     ValidateBroadcastTensorShapesMatch(inputTensorInfo,
2951                                        alphaTensorInfo,
2952                                        outputTensorInfo,
2953                                        descriptorName,
2954                                        "input",
2955                                        "alpha");
2956 }
2957 
Validate(const WorkloadInfo & workloadInfo) const2958 void TransposeConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
2959 {
2960     const std::string descriptorName{"TransposeConvolution2dQueueDescriptor"};
2961 
2962     ValidateNumInputs(workloadInfo,  descriptorName, 1);
2963     ValidateNumOutputs(workloadInfo, descriptorName, 1);
2964 
2965     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
2966     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
2967 
2968     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, 4, "input");
2969     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 4, "output");
2970 
2971     ValidatePointer(m_Weight, descriptorName, "weight");
2972 
2973     const TensorInfo& weightTensorInfo = m_Weight->GetTensorInfo();
2974     ValidateTensorNumDimensions(weightTensorInfo, descriptorName, 4, "weight");
2975 
2976     ValidateWeightDataType(inputTensorInfo, weightTensorInfo, descriptorName);
2977 
2978     Optional<TensorInfo> optionalBiasTensorInfo;
2979     if (m_Parameters.m_BiasEnabled)
2980     {
2981         ValidatePointer(m_Bias, descriptorName, "bias");
2982 
2983         optionalBiasTensorInfo = MakeOptional<TensorInfo>(m_Bias->GetTensorInfo());
2984         const TensorInfo& biasTensorInfo = optionalBiasTensorInfo.value();
2985 
2986         ValidateTensorDataType(biasTensorInfo, GetBiasDataType(inputTensorInfo.GetDataType()), descriptorName, "bias");
2987         ValidateBiasTensorQuantization(biasTensorInfo, inputTensorInfo, weightTensorInfo, descriptorName);
2988     }
2989 
2990     ValidatePerAxisQuantization(inputTensorInfo,
2991                                 outputTensorInfo,
2992                                 weightTensorInfo,
2993                                 optionalBiasTensorInfo,
2994                                 descriptorName);
2995 
2996     std::vector<DataType> supportedTypes =
2997     {
2998         DataType::BFloat16,
2999         DataType::Float32,
3000         DataType::Float16,
3001         DataType::QAsymmS8,
3002         DataType::QAsymmU8,
3003         DataType::QSymmS16
3004     };
3005 
3006     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3007     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3008 }
3009 
Validate(const WorkloadInfo & workloadInfo) const3010 void TransposeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3011 {
3012     const std::string descriptorName{"TransposeQueueDescriptor"};
3013 
3014     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3015     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3016 
3017     const PermutationVector& mapping = m_Parameters.m_DimMappings;
3018 
3019     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
3020     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3021 
3022     ValidateTensorNumDimensions(inputTensorInfo,  descriptorName, mapping.GetSize(), "input");
3023     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, mapping.GetSize(), "output");
3024 
3025     for (unsigned int i = 0u; i < mapping.GetSize(); ++i)
3026     {
3027         if (inputTensorInfo.GetShape()[mapping[i]] != outputTensorInfo.GetShape()[i])
3028         {
3029             throw InvalidArgumentException(descriptorName + ": src dimension " + to_string(mapping[i]) +
3030                                            " (=" + to_string(inputTensorInfo.GetShape()[mapping[i]]) + ") " +
3031                                            "must match dst dimension " + to_string(i) +
3032                                            " (=" + to_string(outputTensorInfo.GetShape()[i]) + ")");
3033         }
3034     }
3035 
3036     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3037 }
3038 
Validate(const WorkloadInfo & workloadInfo) const3039 void ChannelShuffleQueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
3040 {
3041     const std::string descriptorName{"TransposeQueueDescriptor"};
3042 
3043     ValidateNumInputs(workloadInfo, descriptorName, 1);
3044     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3045 
3046     const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3047     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3048 
3049     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3050 }
3051 
Validate(const WorkloadInfo & workloadInfo) const3052 void QLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3053 {
3054     const std::string descriptorName{"QLstmQueueDescriptor"};
3055 
3056     // Validate number of inputs/outputs
3057     ValidateNumInputs(workloadInfo,  descriptorName, 3);
3058     ValidateNumOutputs(workloadInfo, descriptorName, 3);
3059 
3060     // Input/output tensor info
3061     auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3062     auto outputStateInInfo = workloadInfo.m_InputTensorInfos[1];
3063     auto cellStateInInfo = workloadInfo.m_InputTensorInfos[2];
3064 
3065     auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3066     auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3067     auto outputInfo = workloadInfo.m_OutputTensorInfos[2];
3068 
3069     // Supported types for various tensors in QLSTM
3070     std::vector<DataType> inputOutputSupportedTypes =
3071     {
3072         DataType::QAsymmS8
3073     };
3074 
3075     std::vector<DataType> cellStateSupportedTypes =
3076     {
3077         DataType::QSymmS16
3078     };
3079 
3080     std::vector<DataType> weightsSupportedTypes =
3081     {
3082         DataType::QSymmS8
3083     };
3084 
3085     std::vector<DataType> layerNormPeepholeWeightsSupportedTypes =
3086     {
3087         DataType::QSymmS16
3088     };
3089 
3090     std::vector<DataType> biasSupportedTypes =
3091     {
3092         DataType::Signed32
3093     };
3094 
3095     // Validate types of input/output tensors
3096     ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3097     ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3098     ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3099 
3100     ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3101     ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3102     ValidateDataTypes(outputInfo, inputOutputSupportedTypes, descriptorName);
3103 
3104     // Validate matching types of input/output tensors
3105     ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3106     ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3107                                  "outputStateIn", "outputStateOut");
3108     ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3109 
3110     // Infer number of batches, number of units, input size and output size from tensor dimensions
3111     const uint32_t numBatches = inputInfo.GetShape()[0];
3112     const uint32_t inputSize  = inputInfo.GetShape()[1];
3113     const uint32_t outputSize = outputStateInInfo.GetShape()[1];
3114     const uint32_t numUnits = cellStateInInfo.GetShape()[1];
3115 
3116     // Validate number of dimensions and number of elements for input/output tensors
3117     ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3118     ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3119     ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * numUnits), descriptorName + " cellStateIn");
3120 
3121     ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3122     ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * numUnits), descriptorName + " cellStateOut");
3123     ValidateTensorNumDimNumElem(outputInfo, 2, (numBatches * outputSize), descriptorName + " output");
3124 
3125     // Validate number of dimensions and number of elements for MANDATORY weight tensors
3126     ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3127     auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3128     ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (numUnits * inputSize), " InputToForgetWeights");
3129 
3130     ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3131     auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3132     ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (numUnits * inputSize), " InputToCellWeights");
3133 
3134     ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3135     auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3136     ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (numUnits * inputSize), " InputToOutputWeights");
3137 
3138     ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3139     auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3140     ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (numUnits * outputSize),
3141                                 " RecurrentToForgetWeights");
3142 
3143     ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3144     auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3145     ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3146 
3147     ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3148     auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3149     ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (numUnits * outputSize), " RecurrentToCellWeights");
3150 
3151     // Validate data types for MANDATORY weights tensors (all should match each other)
3152     ValidateDataTypes(inputToForgetWeightsInfo, weightsSupportedTypes, descriptorName);
3153 
3154     ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToCellWeightsInfo, descriptorName,
3155                                  "inputToForgetWeights", "inputToCellWeights");
3156     ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3157                                  "inputToForgetWeights", "inputToOutputWeights");
3158 
3159     ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3160                                  "inputToForgetWeights", "recurrentToForgeteights");
3161     ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3162                                  "inputToForgetWeights", "recurrentToCellWeights");
3163     ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3164                                  "inputToForgetWeights", "recurrentToOutputWeights");
3165 
3166     // Validate number of dimensions and number of elements for MANDATORY bias tensors
3167     ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3168     auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3169     ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, numUnits, " ForgetGateBias");
3170 
3171     ValidatePointer(m_CellBias, descriptorName, "CellBias");
3172     auto cellBiasInfo = m_CellBias->GetTensorInfo();
3173     ValidateTensorNumDimNumElem(cellBiasInfo, 1, numUnits, " CellBias");
3174 
3175     ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3176     auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3177     ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, numUnits, " OutputGateBias");
3178 
3179     // Validate data types for MANDATORY bias tensors
3180     ValidateDataTypes(forgetGateBiasInfo, biasSupportedTypes, descriptorName);
3181 
3182     ValidateTensorDataTypesMatch(forgetGateBiasInfo, cellBiasInfo, descriptorName,
3183                                  "forgetGateBias", "cellBias");
3184     ValidateTensorDataTypesMatch(forgetGateBiasInfo, outputGateBiasInfo, descriptorName,
3185                                  "forgetGateBias", "outputGateBias");
3186 
3187     // Validate OPTIONAL params: CIFG (inputToInputWeights, recurrentToInputWeights, inputGateBias)
3188     const bool allCifgParamsPresentOrNot = ((m_InputToInputWeights && m_RecurrentToInputWeights && m_InputGateBias &&
3189                                              !m_Parameters.m_CifgEnabled) ||
3190                                             (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3191                                              !m_InputGateBias && m_Parameters.m_CifgEnabled));
3192 
3193     if (!allCifgParamsPresentOrNot)
3194     {
3195         throw InvalidArgumentException(descriptorName +
3196                 ": InputToInputWeights, RecurrentToInputWeights and InputGateBias must either all be present "
3197                 "(CIFG disabled) or not be present at all (CIFG enabled). m_Parameters.m_CifgEnabled should be "
3198                 "set appropriately.");
3199     }
3200 
3201     if (!m_Parameters.m_CifgEnabled)
3202     {
3203         // Validate number of dimensions and number of elements
3204         auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3205         ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (numUnits * inputSize), " InputToInputWeights");
3206 
3207         auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3208         ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (numUnits * outputSize),
3209                                     " RecurrentToInputWeights");
3210 
3211         auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3212         ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, numUnits, " InputGateBias");
3213 
3214         // Validate data types
3215         ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, inputToInputWeightsInfo, descriptorName,
3216                                      "inputToForgetWeights", "inputToInputWeights");
3217         ValidateTensorDataTypesMatch(inputToForgetWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3218                                      "inputToForgetWeights", "recurrentToInputWeights");
3219         ValidateTensorDataTypesMatch(forgetGateBiasInfo, inputGateBiasInfo, descriptorName,
3220                                      "forgetGateBias", "inputGateBias");
3221     }
3222 
3223     // Validate OPTIONAL params: Peephole (cellToInputWeights, cellToForgetWeights, cellToOutputWeights)
3224     bool allPeepholeWeightsPresentOrNot =
3225             (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
3226               && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3227              || (!m_CellToInputWeights && !m_CellToForgetWeights
3228                  && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3229 
3230     if (!allPeepholeWeightsPresentOrNot)
3231     {
3232         throw InvalidArgumentException(descriptorName +
3233                 ": CellToInputWeights, CellToForgetWeights and CellToOutputWeights should all be present (Peephole "
3234                 "enabled) or not be present at all (Peephole disabled). CellToInputWeights should only be present "
3235                 "when Peephole is enabled and CIFG is disabled. m_Parameters.m_PeepholeEnabled should be set "
3236                 "appropriately.");
3237     }
3238 
3239     if (m_Parameters.m_PeepholeEnabled)
3240     {
3241         auto cellToForgetWeightsInfo = m_CellToForgetWeights->GetTensorInfo();
3242         ValidateTensorNumDimNumElem(cellToForgetWeightsInfo, 1, numUnits, " cellToForgetWeights");
3243         ValidateDataTypes(cellToForgetWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3244 
3245         auto cellToOutputWeightsInfo = m_CellToOutputWeights->GetTensorInfo();
3246         ValidateTensorNumDimNumElem(cellToOutputWeightsInfo, 1, numUnits, " cellToOutputWeights");
3247         ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToOutputWeightsInfo, descriptorName,
3248                                      "cellToForgetWeight", "cellToOutputWeights");
3249 
3250         if (!m_Parameters.m_CifgEnabled)
3251         {
3252             auto cellToInputWeightsInfo = m_CellToInputWeights->GetTensorInfo();
3253             ValidateTensorNumDimNumElem(cellToInputWeightsInfo, 1, numUnits, " cellToInputWeights");
3254             ValidateTensorDataTypesMatch(cellToForgetWeightsInfo, cellToInputWeightsInfo, descriptorName,
3255                                          "cellToForgetWeights", "cellToInputWeights");
3256         }
3257     }
3258 
3259     // Validate OPTIONAL params: Layer Norm Weights
3260     bool allLayerNormWeightsPresentOrNot =
3261             (((m_InputLayerNormWeights || m_Parameters.m_CifgEnabled) && m_ForgetLayerNormWeights
3262               && m_CellLayerNormWeights && m_OutputLayerNormWeights && m_Parameters.m_LayerNormEnabled)
3263              || (!m_InputLayerNormWeights && !m_ForgetLayerNormWeights && !m_CellLayerNormWeights
3264                  && !m_OutputLayerNormWeights && !m_Parameters.m_LayerNormEnabled));
3265 
3266     if (!allLayerNormWeightsPresentOrNot)
3267     {
3268         throw InvalidArgumentException(descriptorName +
3269                                        ": InputLayerNormWeights, ForgetLayerNormWeights, m_OutputLayerNormWeights "
3270                                        "and CellLayerNormWeights should all be present (Layer Norm enabled) or not "
3271                                        "be present at all (Layer Norm disabled). InputLayerNormWeights should "
3272                                        "only be present when Layer Norm is enabled and CIFG is disabled. "
3273                                        "m_Parameters.m_LayerNormEnabled should be set appropriately.");
3274     }
3275 
3276     if (m_Parameters.m_LayerNormEnabled)
3277     {
3278         auto forgetLayerNormWeightsInfo = m_ForgetLayerNormWeights->GetTensorInfo();
3279         ValidateTensorNumDimNumElem(forgetLayerNormWeightsInfo, 1, numUnits, " forgetLayerNormWeights");
3280         ValidateDataTypes(forgetLayerNormWeightsInfo, layerNormPeepholeWeightsSupportedTypes, descriptorName);
3281 
3282         auto cellLayerNormWeightsInfo = m_CellLayerNormWeights->GetTensorInfo();
3283         ValidateTensorNumDimNumElem(cellLayerNormWeightsInfo, 1, numUnits, " cellLayerNormWeights");
3284         ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, cellLayerNormWeightsInfo, descriptorName,
3285                                      "forgetLayerNormWeights", "cellLayerNormWeights");
3286 
3287         auto outputLayerNormWeightsInfo = m_OutputLayerNormWeights->GetTensorInfo();
3288         ValidateTensorNumDimNumElem(outputLayerNormWeightsInfo, 1, numUnits, " outputLayerNormWeights");
3289         ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, outputLayerNormWeightsInfo, descriptorName,
3290                                      "forgetLayerNormWeights", "outputLayerNormWeights");
3291 
3292         if (!m_Parameters.m_CifgEnabled)
3293         {
3294             auto inputLayerNormWeightsInfo = m_InputLayerNormWeights->GetTensorInfo();
3295             ValidateTensorNumDimNumElem(inputLayerNormWeightsInfo, 1, numUnits, " inputLayerNormWeights");
3296             ValidateTensorDataTypesMatch(forgetLayerNormWeightsInfo, inputLayerNormWeightsInfo, descriptorName,
3297                                          "forgetLayerNormWeights", "inputLayerNormWeights");
3298         }
3299     }
3300 
3301     // Validate OPTIONAL params: Projection (projectionWeights, projectionBias)
3302     bool correctProjectionTensorsPresent =
3303             ((!m_ProjectionWeights && !m_ProjectionBias && !m_Parameters.m_ProjectionEnabled) ||
3304             (m_ProjectionWeights && !m_ProjectionBias && m_Parameters.m_ProjectionEnabled) ||
3305             (m_ProjectionWeights && m_ProjectionBias && m_Parameters.m_ProjectionEnabled));
3306 
3307     if (!correctProjectionTensorsPresent)
3308     {
3309         throw InvalidArgumentException(descriptorName +
3310                                        ": If projection is enabled, ProjectionWeights should be present and "
3311                                        "ProjectionBias is optional. If projection is disabled, neither "
3312                                        "ProjectionWeights nor ProjectionBias should be present.");
3313     }
3314 
3315     if (m_Parameters.m_ProjectionEnabled)
3316     {
3317         auto projectionWeightsInfo = m_ProjectionWeights->GetTensorInfo();
3318         ValidateTensorNumDimNumElem(projectionWeightsInfo, 2, (numUnits * outputSize), "ProjectionWeights");
3319         ValidateDataTypes(projectionWeightsInfo, weightsSupportedTypes, descriptorName);
3320 
3321         if (m_ProjectionBias)
3322         {
3323             auto projectionBiasInfo = m_ProjectionBias->GetTensorInfo();
3324             ValidateTensorNumDimNumElem(projectionBiasInfo, 1, outputSize, "ProjectionBias");
3325             ValidateDataTypes(projectionBiasInfo, biasSupportedTypes, descriptorName);
3326         }
3327 
3328     }
3329     else if ((outputInfo.GetQuantizationScale() != m_Parameters.m_HiddenStateScale) &&
3330               outputInfo.GetQuantizationOffset() != m_Parameters.m_HiddenStateZeroPoint) {
3331         throw InvalidArgumentException(descriptorName +
3332                                        ": If projection is disabled, output quantization info (scale, offset) "
3333                                        "should match HiddenStateScale and HiddenStateZeroPoint.");
3334     }
3335 
3336 }
3337 
Validate(const WorkloadInfo & workloadInfo) const3338 void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3339 {
3340     const std::string descriptorName{"QuantizedLstmQueueDescriptor"};
3341 
3342     // Validate number of inputs/outputs
3343     ValidateNumInputs(workloadInfo,  descriptorName, 3);
3344     ValidateNumOutputs(workloadInfo, descriptorName, 2);
3345 
3346     // Input/output tensor infos
3347     auto inputInfo = workloadInfo.m_InputTensorInfos[0];
3348     auto cellStateInInfo = workloadInfo.m_InputTensorInfos[1];
3349     auto outputStateInInfo = workloadInfo.m_InputTensorInfos[2];
3350 
3351     auto cellStateOutInfo = workloadInfo.m_OutputTensorInfos[0];
3352     auto outputStateOutInfo = workloadInfo.m_OutputTensorInfos[1];
3353 
3354     std::vector<DataType> inputOutputSupportedTypes =
3355     {
3356         DataType::QAsymmU8
3357     };
3358 
3359     std::vector<DataType> cellStateSupportedTypes =
3360     {
3361         DataType::QSymmS16
3362     };
3363 
3364     std::vector<DataType> weightsSupportedTypes =
3365     {
3366         DataType::QAsymmU8
3367     };
3368 
3369     std::vector<DataType> biasSupportedTypes =
3370     {
3371         DataType::Signed32
3372     };
3373 
3374     // Validate types of input/output tensors
3375     ValidateDataTypes(inputInfo, inputOutputSupportedTypes, descriptorName);
3376     ValidateDataTypes(cellStateInInfo, cellStateSupportedTypes, descriptorName);
3377     ValidateDataTypes(outputStateInInfo, inputOutputSupportedTypes, descriptorName);
3378 
3379     ValidateDataTypes(cellStateOutInfo, cellStateSupportedTypes, descriptorName);
3380     ValidateDataTypes(outputStateOutInfo, inputOutputSupportedTypes, descriptorName);
3381 
3382     // Validate matching types of input/output tensors
3383     ValidateTensorDataTypesMatch(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3384     ValidateTensorDataTypesMatch(outputStateInInfo, outputStateOutInfo, descriptorName,
3385                                  "outputStateIn", "outputStateOut");
3386     ValidateTensorDataTypesMatch(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3387 
3388     // Validate matching quantization info for input/output tensors
3389     ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
3390     ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
3391     ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
3392 
3393     // Infer number of batches, input size and output size from tensor dimensions
3394     const uint32_t numBatches = inputInfo.GetShape()[0];
3395     const uint32_t inputSize  = inputInfo.GetShape()[1];
3396     const uint32_t outputSize = cellStateInInfo.GetShape()[1];
3397 
3398     // Validate number of dimensions and number of elements for input/output tensors
3399     ValidateTensorNumDimNumElem(inputInfo, 2, (numBatches * inputSize), descriptorName + " input");
3400     ValidateTensorNumDimNumElem(cellStateInInfo, 2, (numBatches * outputSize), descriptorName + " cellStateIn");
3401     ValidateTensorNumDimNumElem(outputStateInInfo, 2, (numBatches * outputSize), descriptorName + " outputStateIn");
3402     ValidateTensorNumDimNumElem(cellStateOutInfo, 2, (numBatches * outputSize), descriptorName + " cellStateOut");
3403     ValidateTensorNumDimNumElem(outputStateOutInfo, 2, (numBatches * outputSize), descriptorName + " outputStateOut");
3404 
3405     // Validate number of dimensions and number of elements for weights tensors
3406     ValidatePointer(m_InputToInputWeights, descriptorName, "InputToInputWeights");
3407     auto inputToInputWeightsInfo = m_InputToInputWeights->GetTensorInfo();
3408     ValidateTensorNumDimNumElem(inputToInputWeightsInfo, 2, (outputSize * inputSize), " InputToInputWeights");
3409 
3410     ValidatePointer(m_InputToForgetWeights, descriptorName, "InputToForgetWeights");
3411     auto inputToForgetWeightsInfo = m_InputToForgetWeights->GetTensorInfo();
3412     ValidateTensorNumDimNumElem(inputToForgetWeightsInfo, 2, (outputSize * inputSize), " InputToForgetWeights");
3413 
3414     ValidatePointer(m_InputToCellWeights, descriptorName, "InputToCellWeights");
3415     auto inputToCellWeightsInfo = m_InputToCellWeights->GetTensorInfo();
3416     ValidateTensorNumDimNumElem(inputToCellWeightsInfo, 2, (outputSize * inputSize), " InputToCellWeights");
3417 
3418     ValidatePointer(m_InputToOutputWeights, descriptorName, "InputToOutputWeights");
3419     auto inputToOutputWeightsInfo = m_InputToOutputWeights->GetTensorInfo();
3420     ValidateTensorNumDimNumElem(inputToOutputWeightsInfo, 2, (outputSize * inputSize), " InputToOutputWeights");
3421 
3422     ValidatePointer(m_RecurrentToInputWeights, descriptorName, "RecurrentToInputWeights");
3423     auto recurrentToInputWeightsInfo = m_RecurrentToInputWeights->GetTensorInfo();
3424     ValidateTensorNumDimNumElem(recurrentToInputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToInputWeights");
3425 
3426     ValidatePointer(m_RecurrentToForgetWeights, descriptorName, "RecurrentToForgetWeights");
3427     auto recurrentToForgetWeightsInfo = m_RecurrentToForgetWeights->GetTensorInfo();
3428     ValidateTensorNumDimNumElem(recurrentToForgetWeightsInfo, 2, (outputSize * outputSize),
3429                                 " RecurrentToForgetWeights");
3430 
3431     ValidatePointer(m_RecurrentToCellWeights, descriptorName, "RecurrentToCellWeights");
3432     auto recurrentToCellWeightsInfo = m_RecurrentToCellWeights->GetTensorInfo();
3433     ValidateTensorNumDimNumElem(recurrentToCellWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3434 
3435     ValidatePointer(m_RecurrentToOutputWeights, descriptorName, "RecurrentToOutputWeights");
3436     auto recurrentToOutputWeightsInfo = m_RecurrentToOutputWeights->GetTensorInfo();
3437     ValidateTensorNumDimNumElem(recurrentToOutputWeightsInfo, 2, (outputSize * outputSize), " RecurrentToCellWeights");
3438 
3439     // Validate data types for weights tensors (all should match each other)
3440     ValidateDataTypes(inputToInputWeightsInfo, weightsSupportedTypes, descriptorName);
3441 
3442     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToForgetWeightsInfo, descriptorName,
3443                                  "inputToInputWeights", "inputToForgetWeights");
3444     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToCellWeightsInfo, descriptorName,
3445                                  "inputToInputWeights", "inputToCellWeights");
3446     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, inputToOutputWeightsInfo, descriptorName,
3447                                  "inputToInputWeights", "inputToOutputWeights");
3448 
3449     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToInputWeightsInfo, descriptorName,
3450                                  "inputToInputWeights", "recurrentToInputWeights");
3451     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToForgetWeightsInfo, descriptorName,
3452                                  "inputToInputWeights", "recurrentToForgeteights");
3453     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToCellWeightsInfo, descriptorName,
3454                                  "inputToInputWeights", "recurrentToCellWeights");
3455     ValidateTensorDataTypesMatch(inputToInputWeightsInfo, recurrentToOutputWeightsInfo, descriptorName,
3456                                  "inputToInputWeights", "recurrentToOutputWeights");
3457 
3458     // Validate matching quantization info for weight tensors (all should match each other)
3459     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToForgetWeightsInfo,
3460                                     descriptorName, "inputToInputWeights", "inputToForgetWeights");
3461     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToCellWeightsInfo,
3462                                     descriptorName, "inputToInputWeights", "inputToCellWeights");
3463     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, inputToOutputWeightsInfo,
3464                                     descriptorName, "inputToInputWeights", "inputToOutputWeights");
3465 
3466     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToInputWeightsInfo,
3467                                     descriptorName, "inputToInputWeights", "recurrentToInputWeights");
3468     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToForgetWeightsInfo,
3469                                     descriptorName, "inputToInputWeights", "recurrentToForgetWeights");
3470     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToCellWeightsInfo,
3471                                     descriptorName, "inputToInputWeights", "recurrentToCellWeights");
3472     ValidateTensorQuantizationSpace(inputToInputWeightsInfo, recurrentToOutputWeightsInfo,
3473                                     descriptorName, "inputToInputWeights", "recurrentToOutputWeights");
3474 
3475     // Validate number of dimensions and number of elements in bias tensors
3476     ValidatePointer(m_InputGateBias, descriptorName, "InputGateBias");
3477     auto inputGateBiasInfo = m_InputGateBias->GetTensorInfo();
3478     ValidateTensorNumDimNumElem(inputGateBiasInfo, 1, outputSize, " InputGateBias");
3479 
3480     ValidatePointer(m_ForgetGateBias, descriptorName, "ForgetGateBias");
3481     auto forgetGateBiasInfo = m_ForgetGateBias->GetTensorInfo();
3482     ValidateTensorNumDimNumElem(forgetGateBiasInfo, 1, outputSize, " ForgetGateBias");
3483 
3484     ValidatePointer(m_CellBias, descriptorName, "CellBias");
3485     auto cellBiasInfo = m_CellBias->GetTensorInfo();
3486     ValidateTensorNumDimNumElem(cellBiasInfo, 1, outputSize, " CellBias");
3487 
3488     ValidatePointer(m_OutputGateBias, descriptorName, "OutputGateBias");
3489     auto outputGateBiasInfo = m_OutputGateBias->GetTensorInfo();
3490     ValidateTensorNumDimNumElem(outputGateBiasInfo, 1, outputSize, " OutputGateBias");
3491 
3492     // Validate data types for bias tensors (all should match each other)
3493     ValidateDataTypes(inputGateBiasInfo, biasSupportedTypes, descriptorName);
3494 
3495     ValidateTensorDataTypesMatch(inputGateBiasInfo, forgetGateBiasInfo, descriptorName,
3496                                  "inputGateBias", "forgetGateBias");
3497     ValidateTensorDataTypesMatch(inputGateBiasInfo, cellBiasInfo, descriptorName,
3498                                  "inputGateBias", "cellBias");
3499     ValidateTensorDataTypesMatch(inputGateBiasInfo, outputGateBiasInfo, descriptorName,
3500                                  "inputGateBias", "outputGateBias");
3501 
3502     // Validate bias tensor quantization info
3503     ValidateBiasTensorQuantization(inputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3504     ValidateBiasTensorQuantization(forgetGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3505     ValidateBiasTensorQuantization(cellBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3506     ValidateBiasTensorQuantization(outputGateBiasInfo, inputInfo, inputToInputWeightsInfo, descriptorName);
3507 }
3508 
Validate(const WorkloadInfo & workloadInfo) const3509 void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3510 {
3511     const std::string descriptorName{"AbsQueueDescriptor"};
3512 
3513     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3514     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3515 
3516     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
3517     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3518 
3519     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3520 
3521     std::vector<DataType> supportedTypes =
3522     {
3523         DataType::BFloat16,
3524         DataType::Float16,
3525         DataType::Float32,
3526         DataType::QAsymmS8,
3527         DataType::QAsymmU8,
3528         DataType::QSymmS16,
3529         DataType::Signed32
3530     };
3531 
3532     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3533     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3534 }
3535 
Validate(const WorkloadInfo & workloadInfo) const3536 void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3537 {
3538     const std::string descriptorName{"SliceQueueDescriptor"};
3539 
3540     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3541     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3542 
3543     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
3544     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3545 
3546     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3547 
3548     const unsigned int rank = inputTensorInfo.GetNumDimensions();
3549     if (rank > 4)
3550     {
3551         throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
3552     }
3553 
3554     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
3555 
3556     // Check if m_Begin and m_Size have the expected length
3557     if (m_Parameters.m_Begin.size() != rank)
3558     {
3559         throw InvalidArgumentException(descriptorName +
3560             ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
3561     }
3562     if (m_Parameters.m_Size.size() != rank)
3563     {
3564         throw InvalidArgumentException(descriptorName +
3565             ": Length of size descriptor must equal rank " + std::to_string(rank));
3566     }
3567 
3568     // Check if the shape of the output tensor matches m_Size
3569     const TensorShape& outputShape = outputTensorInfo.GetShape();
3570     for (unsigned int i = 0u; i < rank; ++i)
3571     {
3572         if (m_Parameters.m_Size[i] != outputShape[i])
3573         {
3574             throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
3575         }
3576     }
3577 
3578     // Check if the sum of begin offset and size in a given dimension
3579     // does not exceed the size of corresponding input
3580     const TensorShape& inputShape  = inputTensorInfo.GetShape();
3581     for(unsigned int i = 0u; i < rank; ++i)
3582     {
3583         if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
3584         {
3585             throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
3586                 std::to_string(i) + " exceeds input size.");
3587         }
3588     }
3589 }
3590 
Validate(const WorkloadInfo & workloadInfo) const3591 void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3592 {
3593     const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
3594 
3595     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3596     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3597 
3598     const TensorInfo& inputInfo  = workloadInfo.m_InputTensorInfos[0];
3599     const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
3600 
3601     ValidateTensorNumDimensions(inputInfo,  descriptorName, 4, "input");
3602     ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
3603 
3604     std::vector<DataType> supportedTypes =
3605     {
3606         DataType::BFloat16,
3607         DataType::Float32,
3608         DataType::Float16,
3609         DataType::QAsymmS8,
3610         DataType::QAsymmU8,
3611         DataType::QSymmS16
3612     };
3613 
3614     ValidateDataTypes(inputInfo,  supportedTypes, descriptorName);
3615     ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
3616 
3617     ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
3618 
3619     if (m_Parameters.m_BlockSize == 0)
3620     {
3621         throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
3622     }
3623 
3624     DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
3625     const unsigned int wIndex = dimensionIndices.GetWidthIndex();
3626     const unsigned int hIndex = dimensionIndices.GetHeightIndex();
3627     const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
3628 
3629     const TensorShape& outputShape = outputInfo.GetShape();
3630     if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex]  % m_Parameters.m_BlockSize != 0)
3631     {
3632         throw InvalidArgumentException(descriptorName + ": Output width and height shape"
3633                                        "must be divisible by block size.");
3634     }
3635 
3636     const TensorShape& inputShape = inputInfo.GetShape();
3637     if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
3638     {
3639         throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
3640                                        "must be divisible by the square of block size." );
3641     }
3642 }
3643 
Validate(const WorkloadInfo & workloadInfo) const3644 void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3645 {
3646     const std::string descriptorName{"ComparisonQueueDescriptor"};
3647 
3648     ValidateNumInputs(workloadInfo,  descriptorName, 2);
3649     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3650 
3651     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3652     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3653     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3654 
3655     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3656                                        inputTensorInfo1,
3657                                        outputTensorInfo,
3658                                        descriptorName,
3659                                        "input_0",
3660                                        "input_1");
3661 
3662     if (outputTensorInfo.GetDataType() != DataType::Boolean)
3663     {
3664         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3665     }
3666 }
3667 
Validate(const WorkloadInfo & workloadInfo) const3668 void ElementwiseBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3669 {
3670     const std::string descriptorName{"ElementwiseBinaryQueueDescriptor"};
3671 
3672     ValidateNumInputs(workloadInfo,  descriptorName, 2);
3673     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3674 
3675     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3676     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3677     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3678 
3679     std::vector<DataType> supportedTypes =
3680             {
3681                     DataType::BFloat16,
3682                     DataType::Float16,
3683                     DataType::Float32,
3684                     DataType::QAsymmS8,
3685                     DataType::QAsymmU8,
3686                     DataType::QSymmS16,
3687                     DataType::Signed32
3688             };
3689 
3690     ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName);
3691     ValidateDataTypes(inputTensorInfo1, supportedTypes, descriptorName);
3692 
3693     ValidateTensorDataTypesMatch(inputTensorInfo0, outputTensorInfo, descriptorName, "input", "output");
3694     ValidateTensorDataTypesMatch(inputTensorInfo1, outputTensorInfo, descriptorName, "input", "output");
3695 }
3696 
Validate(const WorkloadInfo & workloadInfo) const3697 void ElementwiseUnaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3698 {
3699     const std::string descriptorName{"ElementwiseUnaryQueueDescriptor"};
3700 
3701     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3702     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3703 
3704     const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
3705     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3706 
3707     ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3708 
3709     std::vector<DataType> supportedTypes =
3710     {
3711         DataType::BFloat16,
3712         DataType::Float16,
3713         DataType::Float32,
3714         DataType::QAsymmS8,
3715         DataType::QAsymmU8,
3716         DataType::QSymmS16,
3717         DataType::Signed32
3718     };
3719 
3720     std::vector<DataType> logicalSupportedTypes =
3721     {
3722         DataType::Boolean
3723     };
3724 
3725     if (m_Parameters.m_Operation == UnaryOperation::LogicalNot)
3726     {
3727         ValidateDataTypes(inputTensorInfo, logicalSupportedTypes, descriptorName);
3728     }
3729     else
3730     {
3731         ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3732     }
3733 
3734 
3735     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3736 }
3737 
Validate(const WorkloadInfo & workloadInfo) const3738 void RankQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3739 {
3740     const std::string descriptorName{"RankQueueDescriptor"};
3741 
3742     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3743     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3744 
3745     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
3746     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3747 
3748     ValidateTensorNumDimensions(outputTensorInfo, descriptorName, 1, "output");
3749     ValidateTensorNumElements(outputTensorInfo, descriptorName, 1, "output");
3750 
3751     std::vector<DataType> supportedTypes =
3752     {
3753         DataType::BFloat16,
3754         DataType::Float16,
3755         DataType::Float32,
3756         DataType::QAsymmS8,
3757         DataType::QAsymmU8,
3758         DataType::QSymmS8,
3759         DataType::QSymmS16,
3760         DataType::Signed32
3761     };
3762 
3763     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3764     ValidateDataTypes(outputTensorInfo, { DataType::Signed32 }, descriptorName);
3765 }
3766 
Validate(const WorkloadInfo & workloadInfo) const3767 void LogicalBinaryQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3768 {
3769     const std::string descriptorName{"LogicalBinaryQueueDescriptor"};
3770 
3771     ValidateNumInputs(workloadInfo,  descriptorName, 2);
3772     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3773 
3774     const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
3775     const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
3776     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3777 
3778     ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
3779                                        inputTensorInfo1,
3780                                        outputTensorInfo,
3781                                        descriptorName,
3782                                        "input_0",
3783                                        "input_1");
3784 
3785     if (inputTensorInfo0.GetDataType() != DataType::Boolean)
3786     {
3787         throw InvalidArgumentException(descriptorName + ": Input tensor 0 type must be Boolean.");
3788     }
3789 
3790     if (inputTensorInfo1.GetDataType() != DataType::Boolean)
3791     {
3792         throw InvalidArgumentException(descriptorName + ": Input tensor 1 type must be Boolean.");
3793     }
3794 
3795     if (outputTensorInfo.GetDataType() != DataType::Boolean)
3796     {
3797         throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
3798     }
3799 }
3800 
Validate(const WorkloadInfo & workloadInfo) const3801 void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3802 {
3803     const std::string descriptorName{"ReduceQueueDescriptor"};
3804 
3805     ValidateNumInputs(workloadInfo,  descriptorName, 1);
3806     ValidateNumOutputs(workloadInfo, descriptorName, 1);
3807 
3808     const TensorInfo& inputTensorInfo  = workloadInfo.m_InputTensorInfos[0];
3809     const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
3810 
3811     std::vector<DataType> supportedTypes =
3812     {
3813         DataType::BFloat16,
3814         DataType::Float16,
3815         DataType::Float32,
3816         DataType::QAsymmS8,
3817         DataType::QAsymmU8,
3818         DataType::QSymmS16,
3819         DataType::Signed32
3820     };
3821 
3822     ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
3823     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
3824 }
3825 
Validate(const WorkloadInfo & workloadInfo) const3826 void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
3827 {
3828     // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
3829 
3830     const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
3831 
3832     // check dimensions of all inputs and outputs
3833     if (workloadInfo.m_InputTensorInfos.size() != 3)
3834     {
3835         throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
3836     }
3837     if (workloadInfo.m_OutputTensorInfos.size() != 3)
3838     {
3839         throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
3840     }
3841 
3842     std::vector<DataType> supportedTypes =
3843     {
3844         DataType::Float32,
3845         DataType::QAsymmS8
3846     };
3847 
3848     // check for supported type of one input and match them with all the other input and output
3849     ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
3850 
3851     // Making sure clipping parameters have valid values.
3852     // == 0 means no clipping
3853     //  > 0 means clipping
3854     if (m_Parameters.m_ClippingThresCell < 0.0f)
3855     {
3856         throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
3857     }
3858     if (m_Parameters.m_ClippingThresProj < 0.0f)
3859     {
3860         throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
3861     }
3862 
3863     unsigned int batchIndx = 0;
3864     unsigned int inputIndx = 1;
3865     uint32_t timeStep = 1;
3866     unsigned int timeIndx = 1;
3867     inputIndx = 2;
3868     if (m_Parameters.m_TimeMajor)
3869     {
3870         batchIndx = 1;
3871         timeIndx = 0;
3872 
3873     }
3874     timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
3875 
3876     // Inferring batch size, number of outputs and number of cells from the inputs.
3877     const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
3878     const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
3879     ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
3880     const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
3881     ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
3882     const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
3883 
3884     // input tensor
3885     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
3886                                 descriptorName + " input_0");
3887     // outputStateInTensor
3888     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
3889                                 descriptorName + " input_1");
3890     // outputStateInTensor
3891     ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
3892                                 descriptorName + " input_2");
3893 
3894     // outputTensor
3895     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output),
3896                                 descriptorName + " output_0");
3897 
3898     // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
3899     if ( m_InputToInputWeights )
3900     {
3901         ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
3902                                       (n_cell * n_input), "InputLayerNormWeights");
3903     }
3904 
3905     ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
3906     ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
3907                                   (n_cell * n_input), "InputToForgetWeights");
3908 
3909     ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
3910     ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
3911                                   (n_cell * n_input), "InputToCellWeights");
3912 
3913     if ( m_RecurrentToInputWeights )
3914     {
3915         ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
3916                                       (n_cell * n_output), "RecurrentToInputWeights");
3917     }
3918 
3919     ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
3920     ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
3921                                   (n_cell * n_output), "RecurrentToForgetWeights");
3922 
3923     ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
3924     ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
3925                                   (n_cell * n_output), "RecurrentToCellWeights");
3926 
3927     // Make sure the input-gate's parameters are either both present (regular
3928     // LSTM) or not at all (CIFG-LSTM). And CifgEnable is set accordingly.
3929     bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
3930                                      !m_Parameters.m_CifgEnabled) ||
3931                                      (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
3932                                      m_Parameters.m_CifgEnabled));
3933     if (!cifg_weights_all_or_none)
3934     {
3935         throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
3936                                        "RecurrentToInputWeights must either both be present (regular LSTM) "
3937                                        "or both not present (CIFG-LSTM). In addition CifgEnable must be set "
3938                                        "accordingly.");
3939     }
3940 
3941     if ( m_CellToInputWeights )
3942     {
3943         ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
3944                                       n_cell, "CellToInputWeights");
3945     }
3946     if ( m_CellToForgetWeights )
3947     {
3948         ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
3949                                       n_cell, "CellToForgetWeights");
3950     }
3951     if ( m_CellToOutputWeights )
3952     {
3953         ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
3954                                       n_cell, "CellToOutputWeights");
3955     }
3956 
3957     // Making sure the peephole weights are there all or none. And PeepholeEnable is set accordingly.
3958     bool peephole_weights_all_or_none =
3959             (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) &&  m_CellToForgetWeights
3960             && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
3961             || ( !m_CellToInputWeights && !m_CellToForgetWeights
3962             && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
3963     if (!peephole_weights_all_or_none)
3964     {
3965         throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
3966     }
3967 
3968     // Make sure the input gate bias is present only when not a CIFG-LSTM.
3969     if (m_Parameters.m_CifgEnabled)
3970     {
3971         if (m_InputGateBias)
3972         {
3973             throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
3974         }
3975     }
3976     else
3977     {
3978         if (!m_InputGateBias)
3979         {
3980             throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
3981                                            "must be present.");
3982         }
3983         ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
3984                                       n_cell, "InputGateBias");
3985     }
3986 
3987     ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
3988     ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
3989 
3990     ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
3991     ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
3992 
3993     ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
3994     ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
3995 
3996     if (m_ProjectionWeights)
3997     {
3998         ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
3999                                       (n_cell * n_output), "ProjectionWeights");
4000     }
4001     if (m_ProjectionBias)
4002     {
4003         ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
4004     }
4005 
4006     // Making sure the projection tensors are consistent:
4007     // 1) If projection weight is not present, then projection bias should not be
4008     // present.
4009     // 2) If projection weight is present, then projection bias is optional.
4010     bool projecton_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
4011                                         !m_Parameters.m_ProjectionEnabled)
4012                                         || (m_ProjectionWeights && !m_ProjectionBias &&
4013                                         m_Parameters.m_ProjectionEnabled)
4014                                         || (m_ProjectionWeights && m_ProjectionBias &&
4015                                         m_Parameters.m_ProjectionEnabled));
4016     if (!projecton_tensors_consistent)
4017     {
4018         throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
4019     }
4020 
4021     // The four layer normalization weights either all have values or none of them have values. Additionally, if
4022     // CIFG is used, input layer normalization weights tensor is omitted and the other layer normalization weights
4023     // either all have values or none of them have values. Layer normalization is used when the values of all the
4024     // layer normalization weights are present
4025     if (m_InputLayerNormWeights)
4026     {
4027         ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
4028     }
4029     if (m_ForgetLayerNormWeights)
4030     {
4031         ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4032     }
4033     if (m_CellLayerNormWeights)
4034     {
4035         ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4036     }
4037     if (m_OutputLayerNormWeights)
4038     {
4039         ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4040     }
4041 
4042     if (m_Parameters.m_LayerNormEnabled)
4043     {
4044         if (!m_Parameters.m_CifgEnabled)
4045         {
4046             if (!m_InputLayerNormWeights)
4047             {
4048                 throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
4049                                                "disabled but InputLayerNormWeights are not present");
4050             }
4051             ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
4052                                           1, n_cell, "InputLayerNormWeights");
4053         }
4054         else if (m_InputLayerNormWeights)
4055         {
4056             throw InvalidArgumentException(descriptorName + ":InputLayerNormWeights are present while CIFG is "
4057                                            "enabled");
4058         }
4059 
4060         ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
4061                         "ForgetLayerNormWeights");
4062         ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
4063 
4064         ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
4065                         "OutputLayerNormWeights");
4066         ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
4067 
4068         ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
4069                         "CellLayerNormWeights");
4070         ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
4071     }
4072     else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights || m_CellLayerNormWeights)
4073     {
4074         throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
4075                                        "normalisation weights are present.");
4076     }
4077 }
4078 
Validate(const WorkloadInfo & workloadInfo) const4079 void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
4080 {
4081     const std::string descriptorName{"BatchMatMulDescriptor"};
4082 
4083     ValidateNumInputs(workloadInfo,  descriptorName, 2);
4084     ValidateNumOutputs(workloadInfo, descriptorName, 1);
4085 
4086     // Inputs must be: both 2D+
4087     // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
4088     // axes N and I must be the same size
4089 
4090     const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
4091     const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
4092     const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
4093     // Output info has already been inferred
4094 
4095     std::vector<DataType> supportedTypes =
4096     {
4097         DataType::BFloat16,
4098         DataType::Float16,
4099         DataType::Float32,
4100         DataType::QAsymmS8,
4101         DataType::QAsymmU8,
4102         DataType::QSymmS16
4103     };
4104 
4105     ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
4106     ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
4107     ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
4108 
4109     if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
4110         (inputYInfoBeforeParams.GetNumDimensions() < 2))
4111     {
4112         throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
4113     }
4114 
4115     TensorInfo inputXInfoAfterParams;
4116     TensorInfo inputYInfoAfterParams;
4117 
4118     if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
4119        (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
4120     {
4121         throw InvalidArgumentException(descriptorName +
4122             ": Invalid descriptor parameters - Transpose and Adjoint "
4123             "cannot both be true for a given input tensor.");
4124     }
4125     if(m_Parameters.m_TransposeX)
4126     {
4127         inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
4128                                                      BatchMatMulDescriptor::GetPermuteVec(
4129                                                          m_Parameters.m_DataLayoutX,
4130                                                          inputXInfoBeforeParams.GetShape()));
4131     }
4132     else if(m_Parameters.m_AdjointX)
4133     {
4134         auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4135                                                              inputXInfoBeforeParams.GetShape());
4136         if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
4137            inputXInfoBeforeParams.GetShape()[axesToMul.second])
4138         {
4139             throw InvalidArgumentException(descriptorName +
4140                 ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
4141         }
4142         // Shape remains the same as it's square
4143         inputXInfoAfterParams = inputXInfoBeforeParams;
4144     }
4145     else
4146     {
4147         inputXInfoAfterParams = inputXInfoBeforeParams;
4148     }
4149 
4150     if(m_Parameters.m_TransposeY)
4151     {
4152         inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
4153                                                      BatchMatMulDescriptor::GetPermuteVec(
4154                                                          m_Parameters.m_DataLayoutY,
4155                                                          inputYInfoBeforeParams.GetShape()));
4156     }
4157     else if(m_Parameters.m_AdjointY)
4158     {
4159         auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4160                                                              inputYInfoBeforeParams.GetShape());
4161         if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
4162            inputYInfoBeforeParams.GetShape()[axesToMul.second])
4163         {
4164             throw InvalidArgumentException(descriptorName +
4165                 ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
4166         }
4167         // Shape remains the same as it's square
4168         inputYInfoAfterParams = inputYInfoBeforeParams;
4169     }
4170     else
4171     {
4172         inputYInfoAfterParams = inputYInfoBeforeParams;
4173     }
4174 
4175     switch(m_Parameters.m_DataLayoutX)
4176     {
4177         case DataLayout::NCDHW:
4178         case DataLayout::NDHWC:
4179             if(inputXInfoAfterParams.GetNumDimensions() < 3)
4180             {
4181                 throw InvalidArgumentException(descriptorName +
4182                     ": Input tensor X does not have the correct "
4183                     "number of dimensions for the Data Layout that it has been assigned.");
4184             }
4185             break;
4186         case DataLayout::NCHW:
4187         case DataLayout::NHWC:
4188         default:
4189             break;
4190     }
4191 
4192     switch(m_Parameters.m_DataLayoutY)
4193     {
4194         case DataLayout::NCDHW:
4195         case DataLayout::NDHWC:
4196             if(inputYInfoAfterParams.GetNumDimensions() < 3)
4197             {
4198                 throw InvalidArgumentException(descriptorName +
4199                     ": Input tensor Y does not have the correct "
4200                     "number of dimensions for the Data Layout that it has been assigned.");
4201             }
4202             break;
4203         case DataLayout::NCHW:
4204         case DataLayout::NHWC:
4205         default:
4206             break;
4207     }
4208 
4209     auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
4210         inputXInfoAfterParams.GetShape());
4211     auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
4212         inputXInfoBeforeParams.GetShape());
4213 
4214     if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
4215        != inputYInfoAfterParams.GetShape()[axesYToMul.first])
4216     {
4217         throw InvalidArgumentException(descriptorName +
4218             ": The final axis of input tensor X must be the same size as "
4219             "the second last axis of input tensor Y.");
4220     }
4221 
4222     {   // Separate scope so we don't pollute the rest of the scope with our temp variables
4223         // e.g. NHWC isnt compatible with NCHW as of now
4224         DataLayout xLayout = m_Parameters.m_DataLayoutX;
4225         DataLayout yLayout = m_Parameters.m_DataLayoutY;
4226 
4227         if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
4228         {
4229             if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
4230             {
4231                 throw InvalidArgumentException(descriptorName +
4232                     ": Invalid input tensor data layout combination.");
4233             }
4234         }
4235         if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
4236         {
4237             if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
4238             {
4239                 throw InvalidArgumentException(descriptorName +
4240                     ": Invalid input tensor data layout combination.");
4241             }
4242         }
4243     }
4244 
4245     // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
4246     unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
4247                                                 inputYInfoAfterParams.GetNumDimensions());
4248     if(outputTensorDimSize-2 > 0)
4249     {
4250         TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4251                                           DataType::Float32);
4252         TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4253                                           DataType::Float32);
4254         TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
4255                                             DataType::Float32);
4256 
4257         auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, TensorInfo& ti)
4258         {
4259             auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
4260 
4261             for(unsigned int i = 0; i < sizeDiff; i++)
4262             {
4263                 axisIndices.insert(axisIndices.begin(), 1);
4264             }
4265 
4266             for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
4267             {
4268                 ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
4269             }
4270         };
4271 
4272         auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
4273                                                                 inputXInfoAfterParams.GetShape());
4274         auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
4275                                                                 inputYInfoAfterParams.GetShape());
4276 
4277         doAxisExtension(axesXNotMul, tiXNotMul);
4278         doAxisExtension(axesYNotMul, tiYNotMul);
4279 
4280         for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
4281         {
4282             tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
4283                                                  tiYNotMul.GetShape()[i]);
4284         }
4285 
4286         ValidateBroadcastTensorShapesMatch(tiXNotMul,
4287                                            tiYNotMul,
4288                                            tiOutNotMul,
4289                                            descriptorName,
4290                                            "input_X",
4291                                            "input_Y");
4292     }
4293 }
4294 
4295 
4296 } // namespace armnn